diff options
1511 files changed, 88238 insertions, 33779 deletions
@@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# RBE requires a strong hash function, such as SHA256. +startup --host_jvm_args=-Dbazel.DigestFunction=SHA256 + # Build with C++17. build --cxxopt=-std=c++17 @@ -20,13 +23,25 @@ build --stamp --workspace_status_command tools/workspace_status.sh # Enable remote execution so actions are performed on the remote systems. build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:remote --bes_backend=buildeventservice.googleapis.com +build:remote --bes_results_url="https://source.cloud.google.com/results/invocations" +build:remote --bes_timeout=600s build:remote --project_id=gvisor-rbe build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance +build:remote3 --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:remote3 --project_id=gvisor-rbe +build:remote3 --bes_backend=buildeventservice.googleapis.com +build:remote3 --bes_results_url="https://source.cloud.google.com/results/invocations" +build:remote3 --bes_timeout=600s +build:remote3 --remote_instance_name=projects/gvisor-rbe/instances/default_instance + # Enable authentication. This will pick up application default credentials by # default. You can use --google_credentials=some_file.json to use a service # account credential instead. build:remote --google_default_credentials=true build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" +build:remote3 --google_default_credentials=true +build:remote3 --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" # Add a custom platform and toolchain that builds in a privileged docker # container, which is required by our syscall tests. @@ -35,10 +50,15 @@ build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-defa build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604 build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604 build:remote --crosstool_top=@rbe_default//cc:toolchain -build:remote --jobs=50 +build:remote --jobs=100 build:remote --remote_timeout=3600 -# RBE requires a strong hash function, such as SHA256. -startup --host_jvm_args=-Dbazel.DigestFunction=SHA256 +build:remote3 --host_platform=//tools/bazeldefs:rbe_ubuntu1604_bazel3 +build:remote3 --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default_bazel3 +build:remote3 --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3 +build:remote3 --platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3 +build:remote3 --crosstool_top=@rbe_default//cc:toolchain +build:remote3 --jobs=100 +build:remote3 --remote_timeout=3600 # Set flags for uploading to BES in order to view results in the Bazel Build # Results UI. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 49a1ba697..50d187633 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -23,7 +23,7 @@ reproduced with software that is publicly available. Please include the following details of your environment: -* `runsc -v` +* `runsc -version` * `docker version` or `docker info` (if available) * `kubectl version` and `kubectl get nodes` (if using Kubernetes) * `uname -a` diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cf782a580..e28e46352 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,9 +3,11 @@ on: push: branches: - master + - feature/** pull_request: branches: - master + - feature/** jobs: default: @@ -19,3 +21,8 @@ jobs: restore-keys: | ${{ runner.os }}-bazel- - run: make + - run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..." + - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ nogo" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 10c86f5cd..3a6a592d1 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -6,11 +6,19 @@ on: pull_request: branches: - master + - feature/** jobs: generate: runs-on: ubuntu-latest steps: + - id: setup + run: | + if ! [[ -z "${{ secrets.GO_TOKEN }}" ]]; then + echo ::set-output name=has_token::true + else + echo ::set-output name=has_token::false + fi - run: | jq -nc '{"state": "pending", "context": "go tests"}' | \ curl -sL -X POST -d @- \ @@ -19,12 +27,12 @@ jobs: "${{ github.event.pull_request.statuses_url }}" if: github.event_name == 'pull_request' - uses: actions/checkout@v2 - if: github.event_name == 'push' + if: github.event_name == 'push' && steps.setup.outputs.has_token == 'true' with: fetch-depth: 0 token: '${{ secrets.GO_TOKEN }}' - uses: actions/checkout@v2 - if: github.event_name == 'pull_request' + if: github.event_name == 'pull_request' || steps.setup.outputs.has_token != 'true' with: fetch-depth: 0 - uses: actions/setup-go@v2 @@ -42,7 +50,14 @@ jobs: key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }} restore-keys: | ${{ runner.os }}-bazel- - - run: make build TARGETS="//:gopath" + # Create gopath to merge the changes. The first execution will create + # symlinks to the cache, e.g. bazel-bin. Once the cache is setup, delete + # old gopath files that may exist from previous runs (and could contain + # files that are now deleted). Then run gopath again for good. + - run: | + make build TARGETS="//:gopath" + rm -rf bazel-bin/gopath + make build TARGETS="//:gopath" - run: tools/go_branch.sh - run: git checkout go && git clean -f - run: go build ./... diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml index 5e0254111..c53185620 100644 --- a/.github/workflows/issue_reviver.yml +++ b/.github/workflows/issue_reviver.yml @@ -4,11 +4,13 @@ on: - cron: '0 0 * * *' jobs: - label: + issue_reviver: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - run: make run TARGETS="//tools/issue_reviver" + if: github.repository == 'google/gvisor' + - run: make run TARGETS="//tools/github" ARGS="revive" + if: github.repository == 'google/gvisor' env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.travis.yml b/.travis.yml index 9d3141f38..2d9fa80a1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,15 +30,17 @@ services: - docker jobs: include: - - os: linux - arch: amd64 + # AMD64 builds are tested on kokoro, so don't run them in travis to save + # capacity for arm64 builds. + # - os: linux + # arch: amd64 - os: linux arch: arm64 script: # On arm64, we need to create our own pipes for stderr and stdout, # otherwise we will not be able to open /dev/stderr. This is probably # due to AppArmor rules. - - bash -xeo pipefail -c 'uname -a && make smoke-test 2>&1 | cat' + - bash -xeo pipefail -c 'uname -a && make smoke-tests 2>&1 | cat' branches: except: # Skip copybara branches. @@ -30,7 +30,7 @@ doc( permalink = "/community/governance/", subcategory = "Community", visibility = ["//website:__pkg__"], - weight = "91", + weight = "20", ) doc( @@ -57,6 +57,12 @@ build_test( "//test/e2e:integration_test", "//test/image:image_test", "//test/root:root_test", + "//test/benchmarks/base:base_test", + "//test/benchmarks/database:database_test", + "//test/benchmarks/fs:fs_test", + "//test/benchmarks/media:media_test", + "//test/benchmarks/ml:ml_test", + "//test/benchmarks/network:network_test", ], ) @@ -69,7 +75,10 @@ go_path( name = "gopath", mode = "link", deps = [ + # Main binary. "//runsc", + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", # Packages that are not dependencies of //runsc. "//pkg/sentry/kernel/memevent", @@ -200,3 +200,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +------------------ + +Some files carry the following license, noted at the top of each file: + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE.
\ No newline at end of file @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Helpful pretty-printer. +MAKEBANNER := \033[1;34mmake\033[0m +submake = echo -e '$(MAKEBANNER) $1' >&2; $(MAKE) $1 + # Described below. OPTIONS := STARTUP_OPTIONS := @@ -85,7 +89,7 @@ endif ## define images $(1)-%: ## Image tool: $(1) a given image (also may use 'all-images'). - @$(MAKE) -C images $$@ + @$(call submake,-C images $$@) endef rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'. $(eval $(call images,rebuild)) @@ -96,7 +100,7 @@ $(eval $(call images,push)) load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'. $(eval $(call images,load)) list-images: ## List all available images. - @$(MAKE) -C images $$@ + @$(call submake, -C images $$@) ## ## Canonical build and test targets. @@ -106,21 +110,137 @@ list-images: ## List all available images. ## new subsystem or workflow, consider adding a new target here. ## runsc: ## Builds the runsc binary. - @$(MAKE) build TARGETS="//runsc" + @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc") .PHONY: runsc -smoke-test: ## Runs a simple smoke test after build runsc. - @$(MAKE) run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true" +debian: ## Builds the debian packages. + @$(call submake,build OPTIONS="-c opt" TARGETS="//debian:debian") +.PHONY: debian + +smoke-tests: ## Runs a simple smoke test after build runsc. + @$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true") .PHONY: smoke-tests -unit-tests: ## Runs all unit tests in pkg runsc and tools. - @$(MAKE) test OPTIONS="pkg/... runsc/... tools/..." +fuse-tests: + @$(call submake,test OPTIONS="--test_tag_filters fuse" TARGETS="test/fuse/...") +.PHONY: fuse-tests + +unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc. + @$(call submake,test TARGETS="pkg/... runsc/... tools/...") .PHONY: unit-tests -tests: ## Runs all local ptrace system call tests. - @$(MAKE) test OPTIONS="--test_tag_filters runsc_ptrace test/syscalls/..." +tests: ## Runs all unit tests and syscall tests. +tests: unit-tests + @$(call submake,test TARGETS="test/syscalls/...") .PHONY: tests +integration-tests: ## Run all standard integration tests. +integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests +integration-tests: do-tests kvm-tests containerd-test-1.3.4 +.PHONY: integration-tests + +network-tests: ## Run all networking integration tests. +network-tests: iptables-tests packetdrill-tests packetimpact-tests +.PHONY: network-tests + +# Standard integration targets. +INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test + +syscall-%-tests: + @$(call submake,test OPTIONS="--test_tag_filters runsc_$* test/syscalls/...") + +syscall-native-tests: + @$(call submake,test OPTIONS="--test_tag_filters native test/syscalls/...") +.PHONY: syscall-native-tests + +syscall-tests: ## Run all system call tests. +syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests +.PHONY: syscall-tests + +%-runtime-tests: load-runtimes_% + @$(call submake,install-test-runtime) + @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") + +%-runtime-tests_vfs2: load-runtimes_% + @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") + @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") + +do-tests: runsc + @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true") + @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true") + @$(call submake,sudo TARGETS="//runsc" ARGS="do true") +.PHONY: do-tests + +simple-tests: unit-tests # Compatibility target. +.PHONY: simple-tests + +docker-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="vfs1") + @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)") + @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") + @$(call submake,test-runtime RUNTIME="vfs2" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: docker-tests + +overlay-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="overlay" ARGS="--overlay") + @$(call submake,test-runtime RUNTIME="overlay" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: overlay-tests + +swgso-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false") + @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: swgso-tests + +hostnet-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="hostnet" ARGS="--network=host") + @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: hostnet-tests + +kvm-tests: load-basic-images + @(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm + @if ! [[ -w /dev/kvm ]]; then sudo chmod a+rw /dev/kvm; fi + @$(call submake,test TARGETS="//pkg/sentry/platform/kvm:kvm_test") + @$(call submake,install-test-runtime RUNTIME="kvm" ARGS="--platform=kvm") + @$(call submake,test-runtime RUNTIME="kvm" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: kvm-tests + +iptables-tests: load-iptables + @sudo modprobe iptable_filter + @sudo modprobe ip6table_filter + @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test") + @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw") + @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") +.PHONY: iptables-tests + +packetdrill-tests: load-packetdrill + @$(call submake,install-test-runtime RUNTIME="packetdrill") + @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')") +.PHONY: packetdrill-tests + +packetimpact-tests: load-packetimpact + @sudo modprobe iptable_filter + @sudo modprobe ip6table_filter + @$(call submake,install-test-runtime RUNTIME="packetimpact") + @$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetimpact, tests(//...))')") +.PHONY: packetimpact-tests + +# Specific containerd version tests. +containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu + @$(call submake,install-test-runtime RUNTIME="root") + @CONTAINERD_VERSION=$* $(MAKE) sudo TARGETS="tools/installers:containerd" + @$(MAKE) sudo TARGETS="tools/installers:shim" + @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="--runtime=root -test.v" + +# Note that we can't run containerd-test-1.1.8 tests here. +# +# Containerd 1.1.8 should work, but because of a bug in loading images locally +# (https://github.com/kubernetes-sigs/cri-tools/issues/421), we are unable to +# actually drive the tests. The v1 API is tested exclusively through 1.2.13. +containerd-tests: ## Runs all supported containerd version tests. +containerd-tests: containerd-test-1.2.13 +containerd-tests: containerd-test-1.3.4 +containerd-tests: containerd-test-1.4.0-beta.0 + ## ## Website & documentation helpers. ## @@ -138,7 +258,7 @@ WEBSITE_PROJECT := gvisordev WEBSITE_REGION := us-central1 website-build: load-jekyll ## Build the site image locally. - @$(MAKE) run TARGETS="//website:website" + @$(call submake,run TARGETS="//website:website") .PHONY: website-build website-server: website-build ## Run a local server for development. @@ -151,7 +271,7 @@ website-push: website-build ## Push a new image and update the service. website-deploy: website-push ## Deploy a new version of the website. @gcloud run deploy $(WEBSITE_SERVICE) --platform=managed --region=$(WEBSITE_REGION) --project=$(WEBSITE_PROJECT) --image=$(WEBSITE_IMAGE) -.PHONY: website-push +.PHONY: website-deploy ## ## Repository builders. @@ -182,15 +302,17 @@ $(RELEASE_KEY): echo Name-Email: test@example.com >> $$C && \ echo Expire-Date: 0 >> $$C && \ echo %commit >> $$C && \ - gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --keyring $$T --no-tty --gen-key $$C && \ - gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --keyring $$T --secret-keyring $$T > $@; \ + gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --secret-keyring $$T --no-tty --gen-key $$C && \ + gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --secret-keyring $$T > $@; \ rc=$$?; rm -f $$T $$C; exit $$rc release: $(RELEASE_KEY) ## Builds a release. @mkdir -p $(RELEASE_ROOT) @T=$$(mktemp -d /tmp/release.XXXXXX); \ - $(MAKE) copy TARGETS="runsc" DESTINATION=$$T && \ - $(MAKE) copy TARGETS="runsc:runsc-debian" DESTINATION=$$T && \ + $(call submake,copy TARGETS="//runsc:runsc" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//shim/v1:gvisor-containerd-shim" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//shim/v2:containerd-shim-runsc-v1" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//debian:debian" DESTINATION=$$T) && \ NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \ rc=$$?; rm -rf $$T; exit $$rc .PHONY: release @@ -213,42 +335,52 @@ tag: ## Creates and pushes a release tag. ## ifeq (,$(BRANCH_NAME)) RUNTIME := runsc -RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/runsc +RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) else RUNTIME := $(BRANCH_NAME) -RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(BRANCH_NAME) +RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) endif RUNTIME_BIN := $(RUNTIME_DIR)/runsc RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% dev: ## Installs a set of local runtimes. Requires sudo. - @$(MAKE) refresh ARGS="--net-raw" - @$(MAKE) configure RUNTIME="$(RUNTIME)" ARGS="--net-raw" - @$(MAKE) configure RUNTIME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets" - @$(MAKE) configure RUNTIME="$(RUNTIME)-p" ARGS="--net-raw --profile" - @$(MAKE) configure RUNTIME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2" + @$(call submake,refresh ARGS="--net-raw") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2") @sudo systemctl restart docker .PHONY: dev -refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'test-install' first. +refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'install-test-runtime' first. @mkdir -p "$(RUNTIME_DIR)" - @$(MAKE) copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)" && chmod 0755 "$(RUNTIME_BIN)" -.PHONY: install + @$(call submake,copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)") +.PHONY: refresh -test-install: ## Installs the runtime for testing. Requires sudo. - @$(MAKE) refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)" - @$(MAKE) configure +install-test-runtime: ## Installs the runtime for testing. Requires sudo. + @$(call submake,refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)") + @$(call submake,configure RUNTIME_NAME=runsc) + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)") @sudo systemctl restart docker -.PHONY: install-test + @if [[ -f /etc/docker/daemon.json ]]; then \ + sudo chmod 0755 /etc/docker && \ + sudo chmod 0644 /etc/docker/daemon.json; \ + fi +.PHONY: install-test-runtime -configure: ## Configures a single runtime. Requires sudo. Typically called from dev or test-install. - @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) - @echo "Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)" - @echo "Logs are in: $(RUNTIME_LOG_DIR)" +configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-test-runtime. + @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) + @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)" + @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)" @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)" .PHONY: configure test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided. - @$(MAKE) test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)" -.PHONY: runtime-test + @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)") +.PHONY: test-runtime + +nogo: ## Surfaces all nogo findings. + @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...") + @$(call submake,run TARGETS="//tools/github" ARGS="-path=$(BUILD_ROOT) -dry-run nogo") +.PHONY: nogo @@ -2,6 +2,7 @@  [](https://gitter.im/gvisor/community) +[](https://cs.opensource.google/gvisor/gvisor) ## What is gVisor? @@ -58,7 +59,7 @@ Make sure the following dependencies are installed: Build and install the `runsc` binary: -``` +```sh make runsc sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin ``` @@ -67,14 +68,14 @@ sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin To run standard test suites, you can use: -``` +```sh make unit-tests make tests ``` To run specific tests, you can specify the target: -``` +```sh make test TARGETS="//runsc:version_test" ``` @@ -84,12 +85,19 @@ This project uses [bazel][bazel] to build and manage dependencies. A synthetic `go` branch is maintained that is compatible with standard `go` tooling for convenience. -For example, to build `runsc` directly from this branch: +For example, to build and install `runsc` directly from this branch: -``` +```sh echo "module runsc" > go.mod GO111MODULE=on go get gvisor.dev/gvisor/runsc@go -CGO_ENABLED=0 GO111MODULE=on go install gvisor.dev/gvisor/runsc +CGO_ENABLED=0 GO111MODULE=on sudo -E go build -o /usr/local/bin/runsc gvisor.dev/gvisor/runsc +``` + +Subsequently, you can build and install the shim binaries for `containerd`: + +```sh +GO111MODULE=on sudo -E go build -o /usr/local/bin/gvisor-containerd-shim gvisor.dev/gvisor/shim/v1 +GO111MODULE=on sudo -E go build -o /usr/local/bin/containerd-shim-runsc-v1 gvisor.dev/gvisor/shim/v2 ``` Note that this branch is supported in a best effort capacity, and direct @@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # Bazel/starlark utilities. http_archive( name = "bazel_skylib", + sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", ], - sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", ) load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") @@ -42,6 +42,28 @@ http_archive( ], ) +http_archive( + name = "io_bazel_rules_go_bazel3", # To replace the above. + patch_args = ["-p1"], + patches = [ + "//tools/nogo:io_bazel_rules_go-visibility.patch", + ], + sha256 = "87f0fb9747854cb76a0a82430adccb6269f7d394237104a4523b51061c469171", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz", + ], +) + +http_archive( + name = "bazel_gazelle_bazel3", # To replace the above. + sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", + ], +) + load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") go_rules_dependencies() @@ -52,9 +74,6 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") gazelle_dependencies() -# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories" -# block below once 1876 is fixed. -# # The com_google_protobuf repository below would trigger downloading a older # version of org_golang_x_sys. If putting this repository statment in a place # after that of the com_google_protobuf, this statement will not work as @@ -94,26 +113,6 @@ rules_proto_dependencies() rules_proto_toolchains() -# Load python dependencies. -git_repository( - name = "rules_python", - commit = "abc4869e02fe9b3866942e89f07b7341f830e805", - remote = "https://github.com/bazelbuild/rules_python.git", - shallow_since = "1583341286 -0500", -) - -load("@rules_python//python:pip.bzl", "pip_import") - -pip_import( - name = "pydeps", - python_interpreter = "python3", - requirements = "//benchmarks:requirements.txt", -) - -load("@pydeps//:requirements.bzl", "pip_install") - -pip_install() - # Load bazel_toolchain to support Remote Build Execution. # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( @@ -126,6 +125,16 @@ http_archive( ], ) +http_archive( + name = "bazel_toolchains_bazel3", # To replace the above. + sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48", + strip_prefix = "bazel-toolchains-3.1.1", + urls = [ + "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz", + ], +) + # Creates a default toolchain config for RBE. load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig") @@ -159,12 +168,64 @@ load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") grpc_extra_deps() -# External repositories, in sorted order. +# System Call test dependencies. +http_archive( + name = "com_google_absl", + sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93", + strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e", + urls = [ + "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", + ], +) + +http_archive( + name = "com_google_googletest", + sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912", + strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34", + urls = [ + "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", + "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", + ], +) + +http_archive( + name = "com_google_benchmark", + sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", + strip_prefix = "benchmark-1.5.0", + urls = [ + "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", + "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", + ], +) + +# External Go repositories. +# +# Unfortunately, gazelle will automatically parse go modules in the +# repositories and generate new go_repository stanzas. These may not respect +# pins that we have in go.mod or below. So order actually matters here. + +go_repository( + name = "com_github_sirupsen_logrus", + importpath = "github.com/sirupsen/logrus", + replace = "github.com/Sirupsen/logrus", + sum = "h1:cWjBmzJnL1sO88XdqJYmq7aiWClqXIQQMJ3Utgy1f+I=", + version = "v1.4.2", +) + +go_repository( + name = "com_github_containerd_containerd", + build_file_proto_mode = "disable", + importpath = "github.com/containerd/containerd", + sum = "h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=", + version = "v1.3.4", +) + go_repository( name = "com_github_cenkalti_backoff", importpath = "github.com/cenkalti/backoff", - sum = "h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=", - version = "v0.0.0-20190506075156-2146c9339422", + sum = "h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4=", + version = "v1.1.1-0.20190506075156-2146c9339422", ) go_repository( @@ -182,38 +243,31 @@ go_repository( ) go_repository( - name = "com_github_google_go-cmp", - importpath = "github.com/google/go-cmp", - sum = "h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=", - version = "v0.2.0", -) - -go_repository( name = "com_github_google_subcommands", importpath = "github.com/google/subcommands", - sum = "h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=", - version = "v0.0.0-20190508160503-636abe8753b8", + sum = "h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM=", + version = "v1.0.2-0.20190508160503-636abe8753b8", ) go_repository( name = "com_github_google_uuid", importpath = "github.com/google/uuid", - sum = "h1:rXQlD9GXkjA/PQZhmEaF/8Pj/sJfdZJK7GJG0gkS8I0=", - version = "v0.0.0-20171129191014-dec09d789f3d", + sum = "h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=", + version = "v1.0.0", ) go_repository( name = "com_github_kr_pretty", importpath = "github.com/kr/pretty", - sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=", - version = "v0.2.0", + sum = "h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=", + version = "v0.1.0", ) go_repository( name = "com_github_kr_pty", importpath = "github.com/kr/pty", - sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=", - version = "v1.1.1", + sum = "h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=", + version = "v1.1.4-0.20190131011033-7dc38fb350b1", ) go_repository( @@ -225,15 +279,9 @@ go_repository( go_repository( name = "com_github_mohae_deepcopy", - commit = "c48cc78d482608239f6c4c92a4abd87eb8761c90", importpath = "github.com/mohae/deepcopy", -) - -go_repository( - name = "com_github_opencontainers_runtime-spec", - importpath = "github.com/opencontainers/runtime-spec", - sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=", - version = "v0.1.2-0.20171211145439-b2d941ef6a78", + sum = "h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c=", + version = "v0.0.0-20170308212314-bb9b5e7adda9", ) go_repository( @@ -246,65 +294,58 @@ go_repository( go_repository( name = "com_github_vishvananda_netlink", importpath = "github.com/vishvananda/netlink", - sum = "h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=", - version = "v1.0.1-0.20190318003149-adb577d4a45e", -) - -go_repository( - name = "com_github_vishvananda_netns", - importpath = "github.com/vishvananda/netns", - sum = "h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=", - version = "v0.0.0-20171111001504-be1fbeda1936", + sum = "h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w=", + version = "v1.0.1-0.20190930145447-2ec5bdc52b86", ) go_repository( name = "org_golang_google_grpc", build_file_proto_mode = "disable", importpath = "google.golang.org/grpc", - sum = "h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=", - version = "v1.27.1", + sum = "h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=", + version = "v1.29.0", ) go_repository( name = "in_gopkg_check_v1", importpath = "gopkg.in/check.v1", - sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=", - version = "v1.0.0-20190902080502-41f04d3bba15", + sum = "h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=", + version = "v1.0.0-20180628173108-788fd7840127", ) go_repository( name = "org_golang_x_crypto", importpath = "golang.org/x/crypto", - sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=", - version = "v0.0.0-20190308221718-c2843e01d9a2", + sum = "h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=", + version = "v0.0.0-20200622213623-75b288015ac9", ) go_repository( name = "org_golang_x_mod", importpath = "golang.org/x/mod", - sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=", - version = "v0.2.1-0.20200224194123-e5e73c1b9c72", + sum = "h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=", + version = "v0.3.0", ) go_repository( name = "org_golang_x_net", importpath = "golang.org/x/net", - sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=", - version = "v0.0.0-20190620200207-3b0461eec859", + sum = "h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=", + version = "v0.0.0-20200625001655-4c5254603344", ) go_repository( name = "org_golang_x_sync", importpath = "golang.org/x/sync", - sum = "h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=", - version = "v0.0.0-20190423024810-112230192c58", + sum = "h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=", + version = "v0.0.0-20200625203802-6e8e738ad208", ) go_repository( name = "org_golang_x_text", importpath = "golang.org/x/text", - sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=", - version = "v0.3.0", + sum = "h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=", + version = "v0.3.2", ) go_repository( @@ -317,15 +358,15 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:Uglradbb4KfUWaYasZhlsDsGRwHHvRsHoNAEONef0W8=", - version = "v0.0.0-20200131233409-575de47986ce", + sum = "h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE=", + version = "v0.0.0-20200707200213-416e8f4faf8a", ) go_repository( name = "org_golang_x_xerrors", importpath = "golang.org/x/xerrors", - sum = "h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=", - version = "v0.0.0-20190717185122-a985d3407aa7", + sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=", + version = "v0.0.0-20191204190536-9bdfabe68543", ) go_repository( @@ -338,43 +379,106 @@ go_repository( go_repository( name = "com_github_golang_protobuf", importpath = "github.com/golang/protobuf", - sum = "h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=", - version = "v1.3.1", + sum = "h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=", + version = "v1.4.2", +) + +go_repository( + name = "org_golang_x_oauth2", + importpath = "golang.org/x/oauth2", + sum = "h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=", + version = "v0.0.0-20200107190931-bf48bf16ab8d", ) go_repository( - name = "com_github_google_go-github", - importpath = "github.com/google/go-github", - sum = "h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=", - version = "v17.0.0", + name = "com_github_docker_docker", + importpath = "github.com/docker/docker", + sum = "h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w=", + version = "v1.4.2-0.20191028175130-9e7d5ac5ea55", ) go_repository( - name = "org_golang_x_oauth2", - importpath = "golang.org/x/oauth2", - sum = "h1:pE8b58s1HRDMi8RDc79m0HISf9D4TzseP40cEA6IGfs=", - version = "v0.0.0-20191202225959-858c2ad4c8b6", + name = "com_github_docker_go_connections", + importpath = "github.com/docker/go-connections", + sum = "h1:3lOnM9cSzgGwx8VfK/NGOW5fLQ0GjIlCkaktF+n1M6o=", + version = "v0.3.0", ) go_repository( - name = "com_github_google_go-querystring", - importpath = "github.com/google/go-querystring", - sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=", + name = "com_github_pkg_errors", + importpath = "github.com/pkg/errors", + sum = "h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=", + version = "v0.9.1", +) + +go_repository( + name = "com_github_docker_go_units", + importpath = "github.com/docker/go-units", + sum = "h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=", + version = "v0.4.0", +) + +go_repository( + name = "com_github_opencontainers_go_digest", + importpath = "github.com/opencontainers/go-digest", + sum = "h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=", version = "v1.0.0", ) go_repository( - name = "com_google_cloud_go_bigquery", - importpath = "cloud.google.com/go/bigquery", - sum = "h1:K2NyuHRuv15ku6eUpe0DQk5ZykPMnSOnvuVf6IHcjaE=", - version = "v1.5.0", + name = "com_github_docker_distribution", + importpath = "github.com/docker/distribution", + sum = "h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE=", + version = "v2.7.1-0.20190205005809-0d3efadf0154+incompatible", ) go_repository( - name = "org_golang_google_api", - importpath = "google.golang.org/api", - sum = "h1:jz2KixHX7EcCPiQrySzPdnYT7DbINAypCqKZ1Z7GM40=", - version = "v0.20.0", + name = "com_github_davecgh_go_spew", + importpath = "github.com/davecgh/go-spew", + sum = "h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=", + version = "v1.1.1", +) + +go_repository( + name = "com_github_konsorten_go_windows_terminal_sequences", + importpath = "github.com/konsorten/go-windows-terminal-sequences", + sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=", + version = "v1.0.2", +) + +go_repository( + name = "com_github_pmezard_go_difflib", + importpath = "github.com/pmezard/go-difflib", + sum = "h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_stretchr_testify", + importpath = "github.com/stretchr/testify", + sum = "h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=", + version = "v1.4.0", +) + +go_repository( + name = "com_github_opencontainers_image_spec", + importpath = "github.com/opencontainers/image-spec", + sum = "h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=", + version = "v1.0.1", +) + +go_repository( + name = "com_github_microsoft_go_winio", + importpath = "github.com/Microsoft/go-winio", + sum = "h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA=", + version = "v0.4.15-0.20190919025122-fc70bd9a86b5", +) + +go_repository( + name = "com_github_stretchr_objx", + importpath = "github.com/stretchr/objx", + sum = "h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=", + version = "v0.1.1", ) go_repository( @@ -387,16 +491,457 @@ go_repository( go_repository( name = "org_uber_go_multierr", importpath = "go.uber.org/multierr", - sum = "h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=", - version = "v1.5.0", + sum = "h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=", + version = "v1.2.0", ) -# BigQuery Dependencies for Benchmarks go_repository( name = "com_google_cloud_go", importpath = "cloud.google.com/go", - sum = "h1:eoz/lYxKSL4CNAiaUJ0ZfD1J3bfMYbU5B3rwM1C1EIU=", - version = "v0.55.0", + sum = "h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo=", + version = "v0.52.1-0.20200122224058-0482b626c726", +) + +go_repository( + name = "io_opencensus_go", + importpath = "go.opencensus.io", + sum = "h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs=", + version = "v0.22.2", +) + +go_repository( + name = "co_honnef_go_tools", + importpath = "honnef.co/go/tools", + sum = "h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=", + version = "v0.0.1-2019.2.3", +) + +go_repository( + name = "com_github_burntsushi_toml", + importpath = "github.com/BurntSushi/toml", + sum = "h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=", + version = "v0.3.1", +) + +go_repository( + name = "com_github_census_instrumentation_opencensus_proto", + importpath = "github.com/census-instrumentation/opencensus-proto", + sum = "h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=", + version = "v0.2.1", +) + +go_repository( + name = "com_github_client9_misspell", + importpath = "github.com/client9/misspell", + sum = "h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=", + version = "v0.3.4", +) + +go_repository( + name = "com_github_cncf_udpa_go", + importpath = "github.com/cncf/udpa/go", + sum = "h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU=", + version = "v0.0.0-20191209042840-269d4d468f6f", +) + +go_repository( + name = "com_github_containerd_cgroups", + build_file_proto_mode = "disable", + importpath = "github.com/containerd/cgroups", + sum = "h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=", + version = "v0.0.0-20181219155423-39b18af02c41", +) + +go_repository( + name = "com_github_containerd_console", + importpath = "github.com/containerd/console", + sum = "h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=", + version = "v0.0.0-20191206165004-02ecf6a7291e", +) + +go_repository( + name = "com_github_containerd_continuity", + importpath = "github.com/containerd/continuity", + sum = "h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=", + version = "v0.0.0-20200710164510-efbc4488d8fe", +) + +go_repository( + name = "com_github_containerd_fifo", + importpath = "github.com/containerd/fifo", + sum = "h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw=", + version = "v0.0.0-20191213151349-ff969a566b00", +) + +go_repository( + name = "com_github_containerd_go_runc", + importpath = "github.com/containerd/go-runc", + sum = "h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds=", + version = "v0.0.0-20200220073739-7016d3ce2328", +) + +go_repository( + name = "com_github_containerd_ttrpc", + importpath = "github.com/containerd/ttrpc", + sum = "h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc=", + version = "v0.0.0-20200121165050-0be804eadb15", +) + +go_repository( + name = "com_github_coreos_go_systemd", + importpath = "github.com/coreos/go-systemd", + sum = "h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU=", + version = "v0.0.0-20191104093116-d3cd4ed1dbcf", +) + +go_repository( + name = "com_github_docker_go_events", + importpath = "github.com/docker/go-events", + sum = "h1:+pKlWGMw7gf6bQ+oDZB4KHQFypsfjYlq/C4rfL7D3g8=", + version = "v0.0.0-20190806004212-e31b211e4f1c", +) + +go_repository( + name = "com_github_dustin_go_humanize", + importpath = "github.com/dustin/go-humanize", + sum = "h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=", + version = "v0.0.0-20171111073723-bb3d318650d4", +) + +go_repository( + name = "com_github_envoyproxy_go_control_plane", + importpath = "github.com/envoyproxy/go-control-plane", + sum = "h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E=", + version = "v0.9.4", +) + +go_repository( + name = "com_github_envoyproxy_protoc_gen_validate", + importpath = "github.com/envoyproxy/protoc-gen-validate", + sum = "h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_fsnotify_fsnotify", + importpath = "github.com/fsnotify/fsnotify", + sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=", + version = "v1.4.7", +) + +go_repository( + name = "com_github_godbus_dbus", + importpath = "github.com/godbus/dbus", + sum = "h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=", + version = "v0.0.0-20190422162347-ade71ed3457e", +) + +go_repository( + name = "com_github_gogo_googleapis", + importpath = "github.com/gogo/googleapis", + sum = "h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=", + version = "v1.4.0", +) + +go_repository( + name = "com_github_gogo_protobuf", + importpath = "github.com/gogo/protobuf", + sum = "h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=", + version = "v1.3.1", +) + +go_repository( + name = "com_github_golang_glog", + importpath = "github.com/golang/glog", + sum = "h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=", + version = "v0.0.0-20160126235308-23def4e6c14b", +) + +go_repository( + name = "com_github_google_go_cmp", + importpath = "github.com/google/go-cmp", + sum = "h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=", + version = "v0.5.0", +) + +go_repository( + name = "com_github_google_go_github_v28", + importpath = "github.com/google/go-github/v28", + sum = "h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=", + version = "v28.1.2-0.20191108005307-e555eab49ce8", +) + +go_repository( + name = "com_github_google_go_querystring", + importpath = "github.com/google/go-querystring", + sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_golang_lru", + importpath = "github.com/hashicorp/golang-lru", + sum = "h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU=", + version = "v0.5.1", +) + +go_repository( + name = "com_github_hpcloud_tail", + importpath = "github.com/hpcloud/tail", + sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_inconshreveable_mousetrap", + importpath = "github.com/inconshreveable/mousetrap", + sum = "h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_kisielk_errcheck", + importpath = "github.com/kisielk/errcheck", + sum = "h1:reN85Pxc5larApoH1keMBiu2GWtPqXQ1nc9gx+jOU+E=", + version = "v1.2.0", +) + +go_repository( + name = "com_github_kisielk_gotool", + importpath = "github.com/kisielk/gotool", + sum = "h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_microsoft_hcsshim", + importpath = "github.com/Microsoft/hcsshim", + sum = "h1:ZfF0+zZeYdzMIVMZHKtDKJvLHj76XCuVae/jNkjj0IA=", + version = "v0.8.6", +) + +go_repository( + name = "com_github_onsi_ginkgo", + importpath = "github.com/onsi/ginkgo", + sum = "h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=", + version = "v1.10.1", +) + +go_repository( + name = "com_github_onsi_gomega", + importpath = "github.com/onsi/gomega", + sum = "h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=", + version = "v1.7.0", +) + +go_repository( + name = "com_github_opencontainers_runc", + importpath = "github.com/opencontainers/runc", + sum = "h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y=", + version = "v0.1.1", +) + +go_repository( + name = "com_github_opencontainers_runtime_spec", + importpath = "github.com/opencontainers/runtime-spec", + sum = "h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8=", + version = "v1.0.2-0.20181111125026-1722abf79c2f", +) + +go_repository( + name = "com_github_pborman_uuid", + importpath = "github.com/pborman/uuid", + sum = "h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=", + version = "v1.2.0", +) + +go_repository( + name = "com_github_prometheus_client_model", + importpath = "github.com/prometheus/client_model", + sum = "h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=", + version = "v0.0.0-20190812154241-14fe0d1b01d4", +) + +go_repository( + name = "com_github_prometheus_procfs", + importpath = "github.com/prometheus/procfs", + sum = "h1:Lo6mRUjdS99f3zxYOUalftWHUoOGaDRqFk1+j0Q57/I=", + version = "v0.0.0-20190522114515-bc1a522cf7b1", +) + +go_repository( + name = "com_github_spf13_cobra", + importpath = "github.com/spf13/cobra", + sum = "h1:GQkkv3XSnxhAMjdq2wLfEnptEVr+2BNvmHizILHn+d4=", + version = "v0.0.2-0.20171109065643-2da4a54c5cee", +) + +go_repository( + name = "com_github_spf13_pflag", + importpath = "github.com/spf13/pflag", + sum = "h1:j8jxLbQ0+T1DFggy6XoGvyUnrJWPR/JybflPvu5rwS4=", + version = "v1.0.1-0.20171106142849-4c012f6dcd95", +) + +go_repository( + name = "com_github_urfave_cli", + importpath = "github.com/urfave/cli", + sum = "h1:MCfT24H3f//U5+UCrZp1/riVO3B50BovxtDiNn0XKkk=", + version = "v0.0.0-20171014202726-7bc6a0acffa5", +) + +go_repository( + name = "com_github_yuin_goldmark", + importpath = "github.com/yuin/goldmark", + sum = "h1:5tjfNdR2ki3yYQ842+eX2sQHeiwpKJ0RnHO4IYOc4V8=", + version = "v1.1.32", +) + +go_repository( + name = "in_gopkg_airbrake_gobrake_v2", + importpath = "gopkg.in/airbrake/gobrake.v2", + sum = "h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo=", + version = "v2.0.9", +) + +go_repository( + name = "in_gopkg_fsnotify_v1", + importpath = "gopkg.in/fsnotify.v1", + sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=", + version = "v1.4.7", +) + +go_repository( + name = "in_gopkg_gemnasium_logrus_airbrake_hook_v2", + importpath = "gopkg.in/gemnasium/logrus-airbrake-hook.v2", + sum = "h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0=", + version = "v2.1.2", +) + +go_repository( + name = "in_gopkg_tomb_v1", + importpath = "gopkg.in/tomb.v1", + sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=", + version = "v1.0.0-20141024135613-dd632973f1e7", +) + +go_repository( + name = "in_gopkg_yaml_v2", + importpath = "gopkg.in/yaml.v2", + sum = "h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=", + version = "v2.2.8", +) + +go_repository( + name = "org_bazil_fuse", + importpath = "bazil.org/fuse", + sum = "h1:SC+c6A1qTFstO9qmB86mPV2IpYme/2ZoEQ0hrP+wo+Q=", + version = "v0.0.0-20160811212531-371fbbdaa898", +) + +go_repository( + name = "org_golang_google_appengine", + importpath = "google.golang.org/appengine", + sum = "h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=", + version = "v1.6.5", +) + +go_repository( + name = "org_golang_google_genproto", + importpath = "google.golang.org/genproto", + sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=", + version = "v0.0.0-20200117163144-32f20d992d24", +) + +go_repository( + name = "org_golang_google_protobuf", + importpath = "google.golang.org/protobuf", + sum = "h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=", + version = "v1.23.0", +) + +go_repository( + name = "org_golang_x_exp", + importpath = "golang.org/x/exp", + sum = "h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg=", + version = "v0.0.0-20191227195350-da58074b4299", +) + +go_repository( + name = "org_golang_x_lint", + importpath = "golang.org/x/lint", + sum = "h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE=", + version = "v0.0.0-20191125180803-fdd1cda4f05f", +) + +go_repository( + name = "tools_gotest", + importpath = "gotest.tools", + sum = "h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=", + version = "v2.2.0+incompatible", +) + +go_repository( + name = "com_github_burntsushi_xgb", + importpath = "github.com/BurntSushi/xgb", + sum = "h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc=", + version = "v0.0.0-20160522181843-27f122750802", +) + +go_repository( + name = "com_github_chzyer_logex", + importpath = "github.com/chzyer/logex", + sum = "h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=", + version = "v1.1.10", +) + +go_repository( + name = "com_github_chzyer_readline", + importpath = "github.com/chzyer/readline", + sum = "h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8=", + version = "v0.0.0-20180603132655-2972be24d48e", +) + +go_repository( + name = "com_github_chzyer_test", + importpath = "github.com/chzyer/test", + sum = "h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8=", + version = "v0.0.0-20180213035817-a1ea475d72b1", +) + +go_repository( + name = "com_github_go_gl_glfw_v3_3_glfw", + importpath = "github.com/go-gl/glfw/v3.3/glfw", + sum = "h1:b+9H1GAsx5RsjvDFLoS5zkNBzIQMuVKUYQDmxU3N5XE=", + version = "v0.0.0-20191125211704-12ad95a8df72", +) + +go_repository( + name = "com_github_golang_groupcache", + importpath = "github.com/golang/groupcache", + sum = "h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=", + version = "v0.0.0-20191227052852-215e87163ea7", +) + +go_repository( + name = "com_github_google_martian", + importpath = "github.com/google/martian", + sum = "h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no=", + version = "v2.1.0+incompatible", +) + +go_repository( + name = "com_github_google_pprof", + importpath = "github.com/google/pprof", + sum = "h1:DLpL8pWq0v4JYoRpEhDfsJhhJyGKCcQM2WPW2TJs31c=", + version = "v0.0.0-20191218002539-d4f498aebedc", +) + +go_repository( + name = "com_github_google_renameio", + importpath = "github.com/google/renameio", + sum = "h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=", + version = "v0.1.0", ) go_repository( @@ -407,46 +952,127 @@ go_repository( ) go_repository( - name = "io_opencensus_go", - importpath = "go.opencensus.io", - sum = "h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8=", - version = "v0.22.3", + name = "com_github_ianlancetaylor_demangle", + importpath = "github.com/ianlancetaylor/demangle", + sum = "h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c=", + version = "v0.0.0-20181102032728-5e5cf60278f6", ) go_repository( - name = "com_github_golang_groupcache", - importpath = "github.com/golang/groupcache", - sum = "h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=", - version = "v0.0.0-20200121045136-8c9f03a8e57e", + name = "com_github_jstemmer_go_junit_report", + importpath = "github.com/jstemmer/go-junit-report", + sum = "h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=", + version = "v0.9.1", ) -# System Call test dependencies. -http_archive( - name = "com_google_absl", - sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93", - strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e", - urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", - ], +go_repository( + name = "com_github_rogpeppe_go_internal", + importpath = "github.com/rogpeppe/go-internal", + sum = "h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk=", + version = "v1.3.0", ) -http_archive( - name = "com_google_googletest", - sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912", - strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34", - urls = [ - "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", - "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", - ], +go_repository( + name = "com_shuralyov_dmitri_gpu_mtl", + importpath = "dmitri.shuralyov.com/gpu/mtl", + sum = "h1:VpgP7xuJadIUuKccphEpTJnWhS2jkQyMt6Y7pJCD7fY=", + version = "v0.0.0-20190408044501-666a987793e9", ) -http_archive( - name = "com_google_benchmark", - sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", - strip_prefix = "benchmark-1.5.0", - urls = [ - "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", - "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", - ], +go_repository( + name = "in_gopkg_errgo_v2", + importpath = "gopkg.in/errgo.v2", + sum = "h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=", + version = "v2.1.0", +) + +go_repository( + name = "io_rsc_binaryregexp", + importpath = "rsc.io/binaryregexp", + sum = "h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=", + version = "v0.2.0", +) + +go_repository( + name = "org_golang_google_api", + importpath = "google.golang.org/api", + sum = "h1:yzlyyDW/J0w8yNFJIhiAJy4kq74S+1DOLdawELNxFMA=", + version = "v0.15.0", +) + +go_repository( + name = "org_golang_x_image", + importpath = "golang.org/x/image", + sum = "h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=", + version = "v0.0.0-20190802002840-cff245a6509b", +) + +go_repository( + name = "org_golang_x_mobile", + importpath = "golang.org/x/mobile", + sum = "h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=", + version = "v0.0.0-20190719004257-d2bd2a29d028", +) + +go_repository( + name = "com_github_containerd_typeurl", + importpath = "github.com/containerd/typeurl", + sum = "h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o=", + version = "v0.0.0-20200205145503-b45ef1f1f737", +) + +go_repository( + name = "com_github_vishvananda_netns", + importpath = "github.com/vishvananda/netns", + sum = "h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so=", + version = "v0.0.0-20200520041808-52d707b772fe", +) + +go_repository( + name = "com_google_cloud_go_bigquery", + importpath = "cloud.google.com/go/bigquery", + sum = "h1:hL+ycaJpVE9M7nLoiXb/Pn10ENE2u+oddxbD8uu0ZVU=", + version = "v1.0.1", +) + +go_repository( + name = "com_google_cloud_go_datastore", + importpath = "cloud.google.com/go/datastore", + sum = "h1:Kt+gOPPp2LEPWp8CSfxhsM8ik9CcyE/gYu+0r+RnZvM=", + version = "v1.0.0", +) + +go_repository( + name = "com_google_cloud_go_pubsub", + importpath = "cloud.google.com/go/pubsub", + sum = "h1:W9tAK3E57P75u0XLLR82LZyw8VpAnhmyTOxW9qzmyj8=", + version = "v1.0.1", +) + +go_repository( + name = "com_google_cloud_go_storage", + importpath = "cloud.google.com/go/storage", + sum = "h1:VV2nUM3wwLLGh9lSABFgZMjInyUbJeaRSE64WuAIQ+4=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_errwrap", + importpath = "github.com/hashicorp/errwrap", + sum = "h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_go_multierror", + importpath = "github.com/hashicorp/go-multierror", + sum = "h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_dpjacques_clockwork", + importpath = "github.com/dpjacques/clockwork", + sum = "h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U=", + version = "v0.1.1-0.20190114191937-d864eecc357b", ) diff --git a/benchmarks/BUILD b/benchmarks/BUILD deleted file mode 100644 index 389351210..000000000 --- a/benchmarks/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -package(licenses = ["notice"]) - -config_setting( - name = "gcloud_rule", - values = { - "define": "gcloud=off", - }, -) - -py_binary( - name = "benchmarks", - testonly = 1, - srcs = ["run.py"], - data = select({ - ":gcloud_rule": [], - "//conditions:default": [ - "//tools/vm:ubuntu1604", - "//tools/vm:zone", - ], - }), - main = "run.py", - python_version = "PY3", - srcs_version = "PY3", - tags = [ - "local", - "manual", - ], - deps = ["//benchmarks/runner"], -) diff --git a/benchmarks/README.md b/benchmarks/README.md deleted file mode 100644 index 814bcb220..000000000 --- a/benchmarks/README.md +++ /dev/null @@ -1,186 +0,0 @@ -# Benchmark tools - -These scripts are tools for collecting performance data for Docker-based tests. - -## Setup - -The scripts assume the following: - -* There are two sets of machines: one where the scripts will be run - (controller) and one or more machines on which docker containers will be run - (environment). -* The controller machine must have bazel installed along with this source - code. You should be able to run a command like `bazel run //benchmarks -- - --list` -* Environment machines must have docker and the required runtimes installed. - More specifically, you should be able to run a command like: `docker run - --runtime=$RUNTIME your/image`. -* The controller has ssh private key which can be used to login to environment - machines and run docker commands without using `sudo`. This is not required - if running locally via the `run-local` command. -* The docker daemon on each of your environment machines is listening on - `unix:///var/run/docker.sock` (docker's default). - -For configuring the environment manually, consult the -[dockerd documentation][dockerd]. - -## Running benchmarks - -### Locally - -The tool is built to, by default, use Google Cloud Platform to run benchmarks, -but it does support GCP workflows. To run locally, run the following from the -benchmarks directory: - -```bash -bazel run --define gcloud=off //benchmarks -- run-local startup - -... -method,metric,result -startup.empty,startup_time_ms,652.5772 -startup.node,startup_time_ms,1654.4042000000002 -startup.ruby,startup_time_ms,1429.835 -``` - -The above command ran the startup benchmark locally, which consists of three -benchmarks (empty, node, and ruby). Benchmark tools ran it on the default -runtime, runc. Running on another installed runtime, like say runsc, is as -simple as: - -```bash -bazel run --define gcloud=off //benchmarks -- run-local startup --runtime=runsc -``` - -There is help: - -```bash -bazel run --define gcloud=off //benchmarks -- --help -bazel run --define gcloud=off //benchmarks -- run-local --help -``` - -To list available benchmarks, use the `list` commmand: - -```bash -bazel --define gcloud=off run //benchmarks -- list - -... -Benchmark: sysbench.cpu -Metrics: events_per_second - Run sysbench CPU test. Additional arguments can be provided for sysbench. - - :param max_prime: The maximum prime number to search. -``` - -You can choose benchmarks by name or regex like: - -```bash -bazel run --define gcloud=off //benchmarks -- run-local startup.node -... -metric,result -startup_time_ms,1671.7178000000001 - -``` - -or - -```bash -bazel run --define gcloud=off //benchmarks -- run-local s -... -method,metric,result -startup.empty,startup_time_ms,1792.8292 -startup.node,startup_time_ms,3113.5274 -startup.ruby,startup_time_ms,3025.2424 -sysbench.cpu,cpu_events_per_second,12661.47 -sysbench.memory,memory_ops_per_second,7228268.44 -sysbench.mutex,mutex_time,17.4835 -sysbench.mutex,mutex_latency,3496.7 -sysbench.mutex,mutex_deviation,0.04 -syscall.syscall,syscall_time_ns,2065.0 -``` - -You can run parameterized benchmarks, for example to run with different -runtimes: - -```bash -bazel run --define gcloud=off //benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu -``` - -Or with different parameters: - -```bash -bazel run --define gcloud=off //benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu -``` - -### On Google Compute Engine (GCE) - -Benchmarks may be run on GCE in an automated way. The default project configured -for `gcloud` will be used. - -An additional parameter `installers` may be provided to ensure that the latest -runtime is installed from the workspace. See the files in `tools/installers` for -supported install targets. - -```bash -bazel run //benchmarks -- run-gcp --installers=head --runtime=runsc sysbench.cpu -``` - -When running on GCE, the scripts generate a per run SSH key, which is added to -your project. The key is set to expire in GCE after 60 minutes and is stored in -a temporary directory on the local machine running the scripts. - -## Writing benchmarks - -To write new benchmarks, you should familiarize yourself with the structure of -the repository. There are three key components. - -## Harness - -The harness makes use of the [docker py SDK][docker-py]. It is advisable that -you familiarize yourself with that API when making changes, specifically: - -* clients -* containers -* images - -In general, benchmarks need only interact with the `Machine` objects provided to -the benchmark function, which are the machines defined in the environment. These -objects allow the benchmark to define the relationships between different -containers, and parse the output. - -## Workloads - -The harness requires workloads to run. These are all available in the -`workloads` directory. - -In general, a workload consists of a Dockerfile to build it (while these are not -hermetic, in general they should be as fixed and isolated as possible), some -parsers for output if required, parser tests and sample data. Provided the test -is named after the workload package and contains a function named `sample`, this -variable will be used to automatically mock workload output when the `--mock` -flag is provided to the main tool. - -## Writing benchmarks - -Benchmarks define the tests themselves. All benchmarks have the following -function signature: - -```python -def my_func(output) -> float: - return float(output) - -@benchmark(metrics = my_func, machines = 1) -def my_benchmark(machine: machine.Machine, arg: str): - return "3.4432" -``` - -Each benchmark takes a variable amount of position arguments as -`harness.Machine` objects and some set of keyword arguments. It is recommended -that you accept arbitrary keyword arguments and pass them through when -constructing the container under test. - -To write a new benchmark, open a module in the `suites` directory and use the -above signature. You should add a descriptive doc string to describe what your -benchmark is and any test centric arguments. - -[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/ -[docker-py]: https://docker-py.readthedocs.io/en/stable/ diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl deleted file mode 100644 index 56d28223e..000000000 --- a/benchmarks/defs.bzl +++ /dev/null @@ -1,14 +0,0 @@ -"""Provides attributes common to many workload tests.""" - -load("//tools:defs.bzl", "py_requirement") - -test_deps = [ - py_requirement("attrs", direct = False), - py_requirement("atomicwrites", direct = False), - py_requirement("more-itertools", direct = False), - py_requirement("pathlib2", direct = False), - py_requirement("pluggy", direct = False), - py_requirement("py", direct = False), - py_requirement("pytest"), - py_requirement("six", direct = False), -] diff --git a/benchmarks/examples/localhost.yaml b/benchmarks/examples/localhost.yaml deleted file mode 100644 index f70fe0fb7..000000000 --- a/benchmarks/examples/localhost.yaml +++ /dev/null @@ -1,2 +0,0 @@ -client: localhost -server: localhost diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD deleted file mode 100644 index 48c548d59..000000000 --- a/benchmarks/harness/BUILD +++ /dev/null @@ -1,202 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "installers", - srcs = [ - "//tools/installers:head", - "//tools/installers:master", - "//tools/installers:runsc", - ], - mode = "0755", -) - -filegroup( - name = "files", - srcs = [ - ":installers", - ], -) - -py_library( - name = "harness", - srcs = ["__init__.py"], - data = [ - ":files", - ], -) - -py_library( - name = "benchmark_driver", - srcs = ["benchmark_driver.py"], - deps = [ - "//benchmarks/harness/machine_mocks", - "//benchmarks/harness/machine_producers:machine_producer", - "//benchmarks/suites", - ], -) - -py_library( - name = "container", - srcs = ["container.py"], - deps = [ - "//benchmarks/workloads", - py_requirement( - "asn1crypto", - direct = False, - ), - py_requirement( - "chardet", - direct = False, - ), - py_requirement( - "certifi", - direct = False, - ), - py_requirement("docker"), - py_requirement( - "docker-pycreds", - direct = False, - ), - py_requirement( - "idna", - direct = False, - ), - py_requirement( - "ptyprocess", - direct = False, - ), - py_requirement( - "requests", - direct = False, - ), - py_requirement( - "urllib3", - direct = False, - ), - py_requirement( - "websocket-client", - direct = False, - ), - ], -) - -py_library( - name = "machine", - srcs = ["machine.py"], - deps = [ - "//benchmarks/harness", - "//benchmarks/harness:container", - "//benchmarks/harness:ssh_connection", - "//benchmarks/harness:tunnel_dispatcher", - "//benchmarks/harness/machine_mocks", - py_requirement( - "asn1crypto", - direct = False, - ), - py_requirement( - "chardet", - direct = False, - ), - py_requirement( - "certifi", - direct = False, - ), - py_requirement("docker"), - py_requirement( - "docker-pycreds", - direct = False, - ), - py_requirement( - "idna", - direct = False, - ), - py_requirement( - "ptyprocess", - direct = False, - ), - py_requirement( - "requests", - direct = False, - ), - py_requirement( - "six", - direct = False, - ), - py_requirement( - "urllib3", - direct = False, - ), - py_requirement( - "websocket-client", - direct = False, - ), - ], -) - -py_library( - name = "ssh_connection", - srcs = ["ssh_connection.py"], - deps = [ - "//benchmarks/harness", - py_requirement( - "bcrypt", - direct = False, - ), - py_requirement("cffi"), - py_requirement("paramiko"), - py_requirement( - "cryptography", - direct = False, - ), - ], -) - -py_library( - name = "tunnel_dispatcher", - srcs = ["tunnel_dispatcher.py"], - deps = [ - py_requirement( - "asn1crypto", - direct = False, - ), - py_requirement( - "chardet", - direct = False, - ), - py_requirement( - "certifi", - direct = False, - ), - py_requirement("docker"), - py_requirement( - "docker-pycreds", - direct = False, - ), - py_requirement( - "idna", - direct = False, - ), - py_requirement("pexpect"), - py_requirement( - "ptyprocess", - direct = False, - ), - py_requirement( - "requests", - direct = False, - ), - py_requirement( - "urllib3", - direct = False, - ), - py_requirement( - "websocket-client", - direct = False, - ), - ], -) diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py deleted file mode 100644 index 15aa2a69a..000000000 --- a/benchmarks/harness/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Core benchmark utilities.""" - -import getpass -import os -import subprocess -import tempfile - -# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a -# format string that accepts a single string parameter. -LOCAL_WORKLOADS_PATH = os.path.dirname(__file__) + "/../workloads/{}/tar.tar" - -# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the -# remote host. This is a format string that accepts a single string parameter. -REMOTE_WORKLOADS_PATH = "workloads/{}" - -# INSTALLER_ROOT is the set of files that needs to be copied. -INSTALLER_ARCHIVE = os.readlink(os.path.join( - os.path.dirname(__file__), "installers.tar")) - -# SSH_KEY_DIR holds SSH_PRIVATE_KEY for this run. bm-tools paramiko requires -# keys generated with the '-t rsa -m PEM' options from ssh-keygen. This is -# abstracted away from the user. -SSH_KEY_DIR = tempfile.TemporaryDirectory() -SSH_PRIVATE_KEY = "key" - -# DEFAULT_USER is the default user running this script. -DEFAULT_USER = getpass.getuser() - -# DEFAULT_USER_HOME is the home directory of the user running the script. -DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else "" - -# Default directory to remotely installer "installer" targets. -REMOTE_INSTALLERS_PATH = "installers" - - -def make_key(): - """Wraps a valid ssh key in a temporary directory.""" - path = os.path.join(SSH_KEY_DIR.name, SSH_PRIVATE_KEY) - if not os.path.exists(path): - cmd = "ssh-keygen -t rsa -m PEM -b 4096 -f {key} -q -N".format( - key=path).split(" ") - cmd.append("") - subprocess.run(cmd, check=True) - return path - - -def delete_key(): - """Deletes temporary directory containing private key.""" - SSH_KEY_DIR.cleanup() diff --git a/benchmarks/harness/benchmark_driver.py b/benchmarks/harness/benchmark_driver.py deleted file mode 100644 index 9abc21b54..000000000 --- a/benchmarks/harness/benchmark_driver.py +++ /dev/null @@ -1,85 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Main driver for benchmarks.""" - -import copy -import statistics -import threading -import types - -from benchmarks import suites -from benchmarks.harness.machine_producers import machine_producer - - -# pylint: disable=too-many-instance-attributes -class BenchmarkDriver: - """Allocates machines and invokes a benchmark method.""" - - def __init__(self, - producer: machine_producer.MachineProducer, - method: types.FunctionType, - runs: int = 1, - **kwargs): - - self._producer = producer - self._method = method - self._kwargs = copy.deepcopy(kwargs) - self._threads = [] - self.lock = threading.RLock() - self._runs = runs - self._metric_results = {} - - def start(self): - """Starts a benchmark thread.""" - for _ in range(self._runs): - thread = threading.Thread(target=self._run_method) - thread.start() - self._threads.append(thread) - - def join(self): - """Joins the thread.""" - # pylint: disable=expression-not-assigned - [t.join() for t in self._threads] - - def _run_method(self): - """Runs all benchmarks.""" - machines = self._producer.get_machines( - suites.benchmark_machines(self._method)) - try: - result = self._method(*machines, **self._kwargs) - for name, res in result: - with self.lock: - if name in self._metric_results: - self._metric_results[name].append(res) - else: - self._metric_results[name] = [res] - finally: - # Always release. - self._producer.release_machines(machines) - - def median(self): - """Returns the median result, after join is finished.""" - for key, value in self._metric_results.items(): - yield key, [statistics.median(value)] - - def all(self): - """Returns all results.""" - for key, value in self._metric_results.items(): - yield key, value - - def meanstd(self): - """Returns all results.""" - for key, value in self._metric_results.items(): - mean = statistics.mean(value) - yield key, [mean, statistics.stdev(value, xbar=mean)] diff --git a/benchmarks/harness/container.py b/benchmarks/harness/container.py deleted file mode 100644 index 585436e20..000000000 --- a/benchmarks/harness/container.py +++ /dev/null @@ -1,181 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Container definitions.""" - -import contextlib -import logging -import pydoc -import types -from typing import Tuple - -import docker -import docker.errors - -from benchmarks import workloads - - -class Container: - """Abstract container. - - Must be a context manager. - - Usage: - - with Container(client, image, ...): - ... - """ - - def run(self, **env) -> str: - """Run the container synchronously.""" - raise NotImplementedError - - def detach(self, **env): - """Run the container asynchronously.""" - raise NotImplementedError - - def address(self) -> Tuple[str, int]: - """Return the bound address for the container.""" - raise NotImplementedError - - def get_names(self) -> types.GeneratorType: - """Return names of all containers.""" - raise NotImplementedError - - -# pylint: disable=too-many-instance-attributes -class DockerContainer(Container): - """Class that handles creating a docker container.""" - - # pylint: disable=too-many-arguments - def __init__(self, - client: docker.DockerClient, - host: str, - image: str, - count: int = 1, - runtime: str = "runc", - port: int = 0, - **kwargs): - """Trys to setup "count" containers. - - Args: - client: A docker client from dockerpy. - host: The host address the image is running on. - image: The name of the image to run. - count: The number of containers to setup. - runtime: The container runtime to use. - port: The port to reserve. - **kwargs: Additional container options. - """ - assert count >= 1 - assert port == 0 or count == 1 - self._client = client - self._host = host - self._containers = [] - self._count = count - self._image = image - self._runtime = runtime - self._port = port - self._kwargs = kwargs - if port != 0: - self._ports = {"%d/tcp" % port: None} - else: - self._ports = {} - - @contextlib.contextmanager - def detach(self, **env): - env = ["%s=%s" % (key, value) for (key, value) in env.items()] - # Start all containers. - for _ in range(self._count): - try: - # Start the container in a detached mode. - container = self._client.containers.run( - self._image, - detach=True, - remove=True, - runtime=self._runtime, - ports=self._ports, - environment=env, - **self._kwargs) - logging.info("Started detached container %s -> %s", self._image, - container.attrs["Id"]) - self._containers.append(container) - except Exception as exc: - self._clean_containers() - raise exc - try: - # Wait for all containers to be up. - for container in self._containers: - while not container.attrs["State"]["Running"]: - container = self._client.containers.get(container.attrs["Id"]) - yield self - finally: - self._clean_containers() - - def address(self) -> Tuple[str, int]: - assert self._count == 1 - assert self._port != 0 - container = self._client.containers.get(self._containers[0].attrs["Id"]) - port = container.attrs["NetworkSettings"]["Ports"][ - "%d/tcp" % self._port][0]["HostPort"] - return (self._host, port) - - def get_names(self) -> types.GeneratorType: - for container in self._containers: - yield container.name - - def run(self, **env) -> str: - env = ["%s=%s" % (key, value) for (key, value) in env.items()] - return self._client.containers.run( - self._image, - runtime=self._runtime, - ports=self._ports, - remove=True, - environment=env, - **self._kwargs).decode("utf-8") - - def _clean_containers(self): - """Kills all containers.""" - for container in self._containers: - try: - container.kill() - except docker.errors.NotFound: - pass - - -class MockContainer(Container): - """Mock of Container.""" - - def __init__(self, workload: str): - self._workload = workload - - def __enter__(self): - return self - - def run(self, **env): - # Lookup sample data if any exists for the workload module. We use a - # well-defined test locate and a well-defined sample function. - mod = pydoc.locate(workloads.__name__ + "." + self._workload) - if hasattr(mod, "sample"): - return mod.sample(**env) - return "" # No output. - - def address(self) -> Tuple[str, int]: - return ("example.com", 80) - - def get_names(self) -> types.GeneratorType: - yield "mock" - - @contextlib.contextmanager - def detach(self, **env): - yield self diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py deleted file mode 100644 index 5bdc4aa85..000000000 --- a/benchmarks/harness/machine.py +++ /dev/null @@ -1,265 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Machine abstraction passed to benchmarks to run docker containers. - -Abstraction for interacting with test machines. Machines are produced -by Machine producers and represent a local or remote machine. Benchmark -methods in /benchmarks/suite are passed the required number of machines in order -to run the benchmark. Machines contain methods to run commands via bash, -possibly over ssh. Machines also hold a connection to the docker UNIX socket -to run contianers. - - Typical usage example: - - machine = Machine() - machine.run(cmd) - machine.pull(path) - container = machine.container() -""" - -import logging -import os -import re -import subprocess -import time -from typing import List, Tuple - -import docker - -from benchmarks import harness -from benchmarks.harness import container -from benchmarks.harness import machine_mocks -from benchmarks.harness import ssh_connection -from benchmarks.harness import tunnel_dispatcher - -log = logging.getLogger(__name__) - - -class Machine(object): - """The machine object is the primary object for benchmarks. - - Machine objects are passed to each metric function call and benchmarks use - machines to access real connections to those machines. - - Attributes: - _name: Name as a string - """ - _name = "" - - def run(self, cmd: str) -> Tuple[str, str]: - """Convenience method for running a bash command on a machine object. - - Some machines may point to the local machine, and thus, do not have ssh - connections. Run runs a command either local or over ssh and returns the - output stdout and stderr as strings. - - Args: - cmd: The command to run as a string. - - Returns: - The command output. - """ - raise NotImplementedError - - def read(self, path: str) -> str: - """Reads the contents of some file. - - This will be mocked. - - Args: - path: The path to the file to be read. - - Returns: - The file contents. - """ - raise NotImplementedError - - def pull(self, workload: str) -> str: - """Send the given workload to the machine, build and tag it. - - All images must be defined by the workloads directory. - - Args: - workload: The workload name. - - Returns: - The workload tag. - """ - raise NotImplementedError - - def container(self, image: str, **kwargs) -> container.Container: - """Returns a container object. - - Args: - image: The pulled image tag. - **kwargs: Additional container options. - - Returns: - :return: a container.Container object. - """ - raise NotImplementedError - - def sleep(self, amount: float): - """Sleeps the given amount of time.""" - time.sleep(amount) - - def __str__(self): - return self._name - - -class MockMachine(Machine): - """A mocked machine.""" - _name = "mock" - - def run(self, cmd: str) -> Tuple[str, str]: - return "", "" - - def read(self, path: str) -> str: - return machine_mocks.Readfile(path) - - def pull(self, workload: str) -> str: - return workload # Workload is the tag. - - def container(self, image: str, **kwargs) -> container.Container: - return container.MockContainer(image) - - def sleep(self, amount: float): - pass - - -def get_address(machine: Machine) -> str: - """Return a machine's default address.""" - default_route, _ = machine.run("ip route get 8.8.8.8") - return re.search(" src ([0-9.]+) ", default_route).group(1) - - -class LocalMachine(Machine): - """The local machine. - - Attributes: - _name: Name as a string - _docker_client: a pythonic connection to to the local dockerd unix socket. - See: https://github.com/docker/docker-py - """ - - def __init__(self, name): - self._name = name - self._docker_client = docker.from_env() - - def run(self, cmd: str) -> Tuple[str, str]: - process = subprocess.Popen( - cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - return stdout.decode("utf-8"), stderr.decode("utf-8") - - def read(self, path: str) -> bytes: - # Read the exact path locally. - return open(path, "r").read() - - def pull(self, workload: str) -> str: - # Run the docker build command locally. - logging.info("Building %s@%s locally...", workload, self._name) - with open(harness.LOCAL_WORKLOADS_PATH.format(workload), - "rb") as dockerfile: - self._docker_client.images.build( - fileobj=dockerfile, tag=workload, custom_context=True) - return workload # Workload is the tag. - - def container(self, image: str, **kwargs) -> container.Container: - # Return a local docker container directly. - return container.DockerContainer(self._docker_client, get_address(self), - image, **kwargs) - - def sleep(self, amount: float): - time.sleep(amount) - - -class RemoteMachine(Machine): - """Remote machine accessible via an SSH connection. - - Attributes: - _name: Name as a string - _ssh_connection: a paramiko backed ssh connection which can be used to run - commands on this machine - _tunnel: a python wrapper around a port forwarded ssh connection between a - local unix socket and the remote machine's dockerd unix socket. - _docker_client: a pythonic wrapper backed by the _tunnel. Allows sending - docker commands: see https://github.com/docker/docker-py - """ - - def __init__(self, name, **kwargs): - self._name = name - self._ssh_connection = ssh_connection.SSHConnection(name, **kwargs) - self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs) - self._tunnel.connect() - self._docker_client = self._tunnel.get_docker_client() - self._has_installers = False - - def run(self, cmd: str) -> Tuple[str, str]: - return self._ssh_connection.run(cmd) - - def read(self, path: str) -> str: - # Just cat remotely. - stdout, stderr = self._ssh_connection.run("cat '{}'".format(path)) - return stdout + stderr - - def install(self, - installer: str, - results: List[bool] = None, - index: int = -1): - """Method unique to RemoteMachine to handle installation of installers. - - Handles installers, which install things that may change between runs (e.g. - runsc). Usually called from gcloud_producer, which expects this method to - to store results. - - Args: - installer: the installer target to run. - results: Passed by the caller of where to store success. - index: Index for this method to store the result in the passed results - list. - """ - # This generates a tarball of the full installer root (which will generate - # be the full bazel root directory) and sends it over. - if not self._has_installers: - archive = self._ssh_connection.send_installers() - self.run("tar -xvf {archive} -C {dir}".format( - archive=archive, dir=harness.REMOTE_INSTALLERS_PATH)) - self._has_installers = True - - # Execute the remote installer. - self.run("sudo {dir}/{file}".format( - dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) - - if results: - results[index] = True - - def pull(self, workload: str) -> str: - # Push to the remote machine and build. - logging.info("Building %s@%s remotely...", workload, self._name) - remote_path = self._ssh_connection.send_workload(workload) - remote_dir = os.path.dirname(remote_path) - # Workloads are all tarballs. - self.run("tar -xvf {remote_path} -C {remote_dir}".format( - remote_path=remote_path, remote_dir=remote_dir)) - self.run("docker build --tag={} {}".format(workload, remote_dir)) - return workload # Workload is the tag. - - def container(self, image: str, **kwargs) -> container.Container: - # Return a remote docker container. - return container.DockerContainer(self._docker_client, get_address(self), - image, **kwargs) - - def sleep(self, amount: float): - time.sleep(amount) diff --git a/benchmarks/harness/machine_mocks/BUILD b/benchmarks/harness/machine_mocks/BUILD deleted file mode 100644 index c8ec4bc79..000000000 --- a/benchmarks/harness/machine_mocks/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "machine_mocks", - srcs = ["__init__.py"], -) diff --git a/benchmarks/harness/machine_mocks/__init__.py b/benchmarks/harness/machine_mocks/__init__.py deleted file mode 100644 index 00f0085d7..000000000 --- a/benchmarks/harness/machine_mocks/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Machine mock files.""" - -MEMINFO = """\ -MemTotal: 7652344 kB -MemFree: 7174724 kB -MemAvailable: 7152008 kB -Buffers: 7544 kB -Cached: 178856 kB -SwapCached: 0 kB -Active: 270928 kB -Inactive: 68436 kB -Active(anon): 153124 kB -Inactive(anon): 880 kB -Active(file): 117804 kB -Inactive(file): 67556 kB -Unevictable: 0 kB -Mlocked: 0 kB -SwapTotal: 0 kB -SwapFree: 0 kB -Dirty: 900 kB -Writeback: 0 kB -AnonPages: 153000 kB -Mapped: 129120 kB -Shmem: 1044 kB -Slab: 60864 kB -SReclaimable: 22792 kB -SUnreclaim: 38072 kB -KernelStack: 2672 kB -PageTables: 5756 kB -NFS_Unstable: 0 kB -Bounce: 0 kB -WritebackTmp: 0 kB -CommitLimit: 3826172 kB -Committed_AS: 663836 kB -VmallocTotal: 34359738367 kB -VmallocUsed: 0 kB -VmallocChunk: 0 kB -HardwareCorrupted: 0 kB -AnonHugePages: 0 kB -ShmemHugePages: 0 kB -ShmemPmdMapped: 0 kB -CmaTotal: 0 kB -CmaFree: 0 kB -HugePages_Total: 0 -HugePages_Free: 0 -HugePages_Rsvd: 0 -HugePages_Surp: 0 -Hugepagesize: 2048 kB -DirectMap4k: 94196 kB -DirectMap2M: 4624384 kB -DirectMap1G: 3145728 kB -""" - -CONTENTS = { - "/proc/meminfo": MEMINFO, -} - - -def Readfile(path: str) -> str: - """Reads a mock file. - - Args: - path: The target path. - - Returns: - Mocked file contents or None. - """ - return CONTENTS.get(path, None) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD deleted file mode 100644 index 81f19bd08..000000000 --- a/benchmarks/harness/machine_producers/BUILD +++ /dev/null @@ -1,84 +0,0 @@ -load("//tools:defs.bzl", "py_library", "py_requirement") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "harness", - srcs = ["__init__.py"], -) - -py_library( - name = "machine_producer", - srcs = ["machine_producer.py"], -) - -py_library( - name = "mock_producer", - srcs = ["mock_producer.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/harness/machine_producers:gcloud_producer", - "//benchmarks/harness/machine_producers:machine_producer", - ], -) - -py_library( - name = "yaml_producer", - srcs = ["yaml_producer.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/harness/machine_producers:machine_producer", - py_requirement( - "PyYAML", - direct = False, - ), - ], -) - -py_library( - name = "gcloud_mock_recorder", - srcs = ["gcloud_mock_recorder.py"], -) - -py_library( - name = "gcloud_producer", - srcs = ["gcloud_producer.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/harness/machine_producers:gcloud_mock_recorder", - "//benchmarks/harness/machine_producers:machine_producer", - ], -) - -filegroup( - name = "test_data", - srcs = [ - "testdata/get_five.json", - "testdata/get_one.json", - ], -) - -py_library( - name = "gcloud_producer_test_lib", - srcs = ["gcloud_producer_test.py"], - deps = [ - "//benchmarks/harness/machine_producers:machine_producer", - "//benchmarks/harness/machine_producers:mock_producer", - ], -) - -py_test( - name = "gcloud_producer_test", - srcs = [":gcloud_producer_test_lib"], - data = [ - ":test_data", - ], - python_version = "PY3", - tags = [ - "local", - "manual", - ], -) diff --git a/benchmarks/harness/machine_producers/__init__.py b/benchmarks/harness/machine_producers/__init__.py deleted file mode 100644 index 634ef4843..000000000 --- a/benchmarks/harness/machine_producers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/benchmarks/harness/machine_producers/gcloud_mock_recorder.py b/benchmarks/harness/machine_producers/gcloud_mock_recorder.py deleted file mode 100644 index fd9837a37..000000000 --- a/benchmarks/harness/machine_producers/gcloud_mock_recorder.py +++ /dev/null @@ -1,97 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A recorder and replay for testing the GCloudProducer. - -MockPrinter and MockReader handle printing and reading mock data for the -purposes of testing. MockPrinter is passed to GCloudProducer objects. The user -can then run scenarios and record them for playback in tests later. - -MockReader is passed to MockGcloudProducer objects and handles reading the -previously recorded mock data. - -It is left to the user to check if data printed is properly redacted for their -own use. The intended usecase for this class is data coming from gcloud -commands, which will contain public IPs and other instance data. - -The data format is json and printed/read from the ./test_data directory. The -data is the output of subprocess.CompletedProcess objects in json format. - - Typical usage example: - - recorder = MockPrinter() - producer = GCloudProducer(args, recorder) - machines = producer.get_machines(1) - with open("my_file.json") as fd: - recorder.write_out(fd) - - reader = MockReader(filename) - producer = MockGcloudProducer(args, mock) - machines = producer.get_machines(1) - assert len(machines) == 1 -""" - -import io -import json -import subprocess - - -class MockPrinter(object): - """Handles printing Mock data for MockGcloudProducer. - - Attributes: - _records: list of json object records for printing - """ - - def __init__(self): - self._records = [] - - def record(self, entry: subprocess.CompletedProcess): - """Records data and strips out ip addresses.""" - - record = { - "args": entry.args, - "stdout": entry.stdout.decode("utf-8"), - "returncode": str(entry.returncode) - } - self._records.append(record) - - def write_out(self, fd: io.FileIO): - """Prints out the data into the given filepath.""" - fd.write(json.dumps(self._records, indent=4)) - - -class MockReader(object): - """Handles reading Mock data for MockGcloudProducer. - - Attributes: - _records: List[json] records read from the passed in file. - """ - - def __init__(self, filepath: str): - with open(filepath, "rb") as file: - self._records = json.loads(file.read()) - self._i = 0 - - def __iter__(self): - return self - - def __next__(self, args) -> subprocess.CompletedProcess: - """Returns the next record as a CompletedProcess.""" - if self._i < len(self._records): - record = self._records[self._i] - stdout = record["stdout"].encode("ascii") - returncode = int(record["returncode"]) - return subprocess.CompletedProcess( - args=args, returncode=returncode, stdout=stdout, stderr=b"") - raise StopIteration() diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py deleted file mode 100644 index 44d72f575..000000000 --- a/benchmarks/harness/machine_producers/gcloud_producer.py +++ /dev/null @@ -1,250 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A machine producer which produces machine objects using `gcloud`. - -Machine producers produce valid harness.Machine objects which are backed by -real machines. This producer produces those machines on the given user's GCP -account using the `gcloud` tool. - -GCloudProducer creates instances on the given GCP account named like: -`machine-XXXXXXX-XXXX-XXXX-XXXXXXXXXXXX` in a randomized fashion such that name -collisions with user instances shouldn't happen. - - Typical usage example: - - producer = GCloudProducer(args) - machines = producer.get_machines(NUM_MACHINES) - # run stuff on machines with machines[i].run(CMD) - producer.release_machines(NUM_MACHINES) -""" -import datetime -import json -import subprocess -import threading -from typing import List, Dict, Any -import uuid - -from benchmarks.harness import machine -from benchmarks.harness.machine_producers import gcloud_mock_recorder -from benchmarks.harness.machine_producers import machine_producer - - -class GCloudProducer(machine_producer.MachineProducer): - """Implementation of MachineProducer backed by GCP. - - Produces Machine objects backed by GCP instances. - - Attributes: - image: image name as a string. - zone: string to a valid GCP zone. - machine_type: type of GCP to create (e.g. n1-standard-4). - installers: list of installers post-boot. - ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys. - ssh_user: string of user name for ssh_key - ssh_password: string of password for ssh key - internal: if true, use internal IPs of instances. Used if bm-tools is - running on a GCP vm when a firewall is set for external IPs. - mock: a mock printer which will print mock data if required. Mock data is - recorded output from subprocess calls (returncode, stdout, args). - condition: mutex for this class around machine creation and deleteion. - """ - - def __init__(self, - image: str, - zone: str, - machine_type: str, - installers: List[str], - ssh_key_file: str, - ssh_user: str, - ssh_password: str, - internal: bool, - mock: gcloud_mock_recorder.MockPrinter = None): - self.image = image - self.zone = zone - self.machine_type = machine_type - self.installers = installers - self.ssh_key_file = ssh_key_file - self.ssh_user = ssh_user - self.ssh_password = ssh_password - self.internal = internal - self.mock = mock - self.condition = threading.Condition() - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - """Returns requested number of machines backed by GCP instances.""" - if num_machines <= 0: - raise ValueError( - "Cannot ask for {num} machines!".format(num=num_machines)) - with self.condition: - names = self._get_unique_names(num_machines) - instances = self._build_instances(names) - self._add_ssh_key_to_instances(names) - machines = self._machines_from_instances(instances) - - # Install all bits in lock-step. - # - # This will perform paralell installations for however many machines we - # have, but it's easy to track errors because if installing (a, b, c), we - # won't install "c" until "b" is installed on all machines. - for installer in self.installers: - threads = [None] * len(machines) - results = [False] * len(machines) - for i in range(len(machines)): - threads[i] = threading.Thread( - target=machines[i].install, args=(installer, results, i)) - threads[i].start() - for thread in threads: - thread.join() - for result in results: - if not result: - raise NotImplementedError( - "Installers failed on at least one machine!") - - # Add this user to each machine's docker group. - for m in machines: - m.run("sudo setfacl -m user:$USER:rw /var/run/docker.sock") - - return machines - - def release_machines(self, machine_list: List[machine.Machine]): - """Releases the requested number of machines, deleting the instances.""" - if not machine_list: - return - cmd = "gcloud compute instances delete --quiet".split(" ") - names = [str(m) for m in machine_list] - cmd.extend(names) - cmd.append("--zone={zone}".format(zone=self.zone)) - self._run_command(cmd, detach=True) - - def _machines_from_instances( - self, instances: List[Dict[str, Any]]) -> List[machine.Machine]: - """Creates Machine Objects from json data describing created instances.""" - machines = [] - for instance in instances: - name = instance["name"] - external = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"] - internal = instance["networkInterfaces"][0]["networkIP"] - kwargs = { - "hostname": internal if self.internal else external, - "key_path": self.ssh_key_file, - "username": self.ssh_user, - "key_password": self.ssh_password - } - machines.append(machine.RemoteMachine(name=name, **kwargs)) - return machines - - def _get_unique_names(self, num_names) -> List[str]: - """Returns num_names unique names based on data from the GCP project.""" - return ["machine-" + str(uuid.uuid4()) for _ in range(0, num_names)] - - def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]: - """Creates instances using gcloud command. - - Runs the command `gcloud compute instances create` and returns json data - on created instances on success. Creates len(names) instances, one for each - name. - - Args: - names: list of names of instances to create. - - Returns: - List of json data describing created machines. - """ - if not names: - raise ValueError( - "_build_instances cannot create instances without names.") - cmd = "gcloud compute instances create".split(" ") - cmd.extend(names) - cmd.append("--image=" + self.image) - cmd.append("--zone=" + self.zone) - cmd.append("--machine-type=" + self.machine_type) - res = self._run_command(cmd) - data = res.stdout - data = str(data, "utf-8") if isinstance(data, (bytes, bytearray)) else data - return json.loads(data) - - def _add_ssh_key_to_instances(self, names: List[str]) -> None: - """Adds ssh key to instances by calling gcloud ssh command. - - Runs the command `gcloud compute ssh instance_name` on list of images by - name. Tries to ssh into given instance. - - Args: - names: list of machine names to which to add the ssh-key - self.ssh_key_file. - - Raises: - subprocess.CalledProcessError: when underlying subprocess call returns an - error other than 255 (Connection closed by remote host). - TimeoutError: when 3 unsuccessful tries to ssh into the host return 255. - """ - for name in names: - cmd = "gcloud compute ssh {user}@{name}".format( - user=self.ssh_user, name=name).split(" ") - if self.internal: - cmd.append("--internal-ip") - cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file)) - cmd.append("--zone={zone}".format(zone=self.zone)) - cmd.append("--command=uname") - timeout = datetime.timedelta(seconds=5 * 60) - start = datetime.datetime.now() - while datetime.datetime.now() <= timeout + start: - try: - self._run_command(cmd) - break - except subprocess.CalledProcessError: - if datetime.datetime.now() > timeout + start: - raise TimeoutError( - "Could not SSH into instance after 5 min: {name}".format( - name=name)) - - def _run_command(self, - cmd: List[str], - detach: bool = False) -> [None, subprocess.CompletedProcess]: - """Runs command as a subprocess. - - Runs command as subprocess and returns the result. - If this has a mock recorder, use the record method to record the subprocess - call. - - Args: - cmd: command to be run as a list of strings. - detach: if True, run the child process and don't wait for it to return. - - Returns: - Completed process object to be parsed by caller or None if detach=True. - - Raises: - CalledProcessError: if subprocess.run returns an error. - """ - cmd = cmd + ["--format=json"] - if detach: - p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if self.mock: - out, _ = p.communicate() - self.mock.record( - subprocess.CompletedProcess( - returncode=p.returncode, stdout=out, args=p.args)) - return - - res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if self.mock: - self.mock.record(res) - if res.returncode != 0: - raise subprocess.CalledProcessError( - cmd=" ".join(res.args), - output=res.stdout, - stderr=res.stderr, - returncode=res.returncode) - return res diff --git a/benchmarks/harness/machine_producers/gcloud_producer_test.py b/benchmarks/harness/machine_producers/gcloud_producer_test.py deleted file mode 100644 index c8adb2bdc..000000000 --- a/benchmarks/harness/machine_producers/gcloud_producer_test.py +++ /dev/null @@ -1,48 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests GCloudProducer using mock data. - -GCloudProducer produces machines using 'get_machines' and 'release_machines' -methods. The tests check recorded data (jsonified subprocess.CompletedProcess -objects) of the producer producing one and five machines. -""" -import os -import types - -from benchmarks.harness.machine_producers import machine_producer -from benchmarks.harness.machine_producers import mock_producer - -TEST_DIR = os.path.dirname(__file__) - - -def run_get_release(producer: machine_producer.MachineProducer, - num_machines: int, - validator: types.FunctionType = None): - machines = producer.get_machines(num_machines) - assert len(machines) == num_machines - if validator: - validator(machines=machines, cmd="uname -a", workload=None) - producer.release_machines(machines) - - -def test_run_one(): - mock = mock_producer.MockReader(TEST_DIR + "get_one.json") - producer = mock_producer.MockGCloudProducer(mock) - run_get_release(producer, 1) - - -def test_run_five(): - mock = mock_producer.MockReader(TEST_DIR + "get_five.json") - producer = mock_producer.MockGCloudProducer(mock) - run_get_release(producer, 5) diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py deleted file mode 100644 index f5591c026..000000000 --- a/benchmarks/harness/machine_producers/machine_producer.py +++ /dev/null @@ -1,51 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Abstract types.""" - -import threading -from typing import List - -from benchmarks.harness import machine - - -class MachineProducer: - """Abstract Machine producer.""" - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - """Returns the requested number of machines.""" - raise NotImplementedError - - def release_machines(self, machine_list: List[machine.Machine]): - """Releases the given set of machines.""" - raise NotImplementedError - - -class LocalMachineProducer(MachineProducer): - """Produces Local Machines.""" - - def __init__(self, limit: int): - self.limit_sem = threading.Semaphore(value=limit) - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - """Returns the request number of MockMachines.""" - - self.limit_sem.acquire() - return [machine.LocalMachine("local") for _ in range(num_machines)] - - def release_machines(self, machine_list: List[machine.MockMachine]): - """No-op.""" - if not machine_list: - raise ValueError("Cannot release an empty list!") - self.limit_sem.release() - machine_list.clear() diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py deleted file mode 100644 index 37e9cb4b7..000000000 --- a/benchmarks/harness/machine_producers/mock_producer.py +++ /dev/null @@ -1,52 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Producers of mocks.""" - -from typing import List, Any - -from benchmarks.harness import machine -from benchmarks.harness.machine_producers import gcloud_mock_recorder -from benchmarks.harness.machine_producers import gcloud_producer -from benchmarks.harness.machine_producers import machine_producer - - -class MockMachineProducer(machine_producer.MachineProducer): - """Produces MockMachine objects.""" - - def get_machines(self, num_machines: int) -> List[machine.MockMachine]: - """Returns the request number of MockMachines.""" - return [machine.MockMachine() for i in range(num_machines)] - - def release_machines(self, machine_list: List[machine.MockMachine]): - """No-op.""" - return - - -class MockGCloudProducer(gcloud_producer.GCloudProducer): - """Mocks GCloudProducer for testing purposes.""" - - def __init__(self, mock: gcloud_mock_recorder.MockReader, **kwargs): - gcloud_producer.GCloudProducer.__init__( - self, project="mock", ssh_private_key_path="mock", **kwargs) - self.mock = mock - - def _validate_ssh_file(self): - pass - - def _run_command(self, cmd): - return self.mock.pop(cmd) - - def _machines_from_instances( - self, instances: List[Any]) -> List[machine.MockMachine]: - return [machine.MockMachine() for _ in instances] diff --git a/benchmarks/harness/machine_producers/testdata/get_five.json b/benchmarks/harness/machine_producers/testdata/get_five.json deleted file mode 100644 index 32bad1b06..000000000 --- a/benchmarks/harness/machine_producers/testdata/get_five.json +++ /dev/null @@ -1,211 +0,0 @@ -[ - { - "args": [ - "gcloud", - "compute", - "instances", - "list", - "--project", - "project", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "create", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--preemptible", - "--image=ubuntu-1910-eoan-v20191204", - "--zone=us-west1-b", - "--image-project=ubuntu-os-cloud", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "start", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--zone=us-west1-b", - "--project=project", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "delete", - "--quiet", - "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc", - "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92", - "machine-da5859b5-bae6-435d-8005-0202d6f6e065", - "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05", - "machine-1149147d-71e2-43ea-8fe1-49256e5c441c", - "--zone=us-west1-b", - "--format=json" - ], - "stdout": "[]\n", - "returncode": "0" - } -] diff --git a/benchmarks/harness/machine_producers/testdata/get_one.json b/benchmarks/harness/machine_producers/testdata/get_one.json deleted file mode 100644 index c359c19c8..000000000 --- a/benchmarks/harness/machine_producers/testdata/get_one.json +++ /dev/null @@ -1,145 +0,0 @@ -[ - { - "args": [ - "gcloud", - "compute", - "instances", - "list", - "--project", - "linux-testing-user", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "create", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--preemptible", - "--image=ubuntu-1910-eoan-v20191204", - "--zone=us-west1-b", - "--image-project=ubuntu-os-cloud", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "start", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--zone=us-west1-b", - "--project=linux-testing-user", - "--format=json" - ], - "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]", - - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "", - "returncode": "255" - }, - { - "args": [ - "gcloud", - "compute", - "ssh", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools", - "--zone=us-west1-b", - "--command=uname", - "--format=json" - ], - "stdout": "Linux\n[]\n", - "returncode": "0" - }, - { - "args": [ - "gcloud", - "compute", - "instances", - "delete", - "--quiet", - "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc", - "--zone=us-west1-b", - "--format=json" - ], - "stdout": "[]\n", - "returncode": "0" - } -] diff --git a/benchmarks/harness/machine_producers/yaml_producer.py b/benchmarks/harness/machine_producers/yaml_producer.py deleted file mode 100644 index 5d334e480..000000000 --- a/benchmarks/harness/machine_producers/yaml_producer.py +++ /dev/null @@ -1,106 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Producers based on yaml files.""" - -import os -import threading -from typing import Dict -from typing import List - -import yaml - -from benchmarks.harness import machine -from benchmarks.harness.machine_producers import machine_producer - - -class YamlMachineProducer(machine_producer.MachineProducer): - """Loads machines from a yaml file.""" - - def __init__(self, path: str): - self.machines = build_machines(path) - self.max_machines = len(self.machines) - self.machine_condition = threading.Condition() - - def get_machines(self, num_machines: int) -> List[machine.Machine]: - if num_machines > self.max_machines: - raise ValueError( - "Insufficient Ammount of Machines. {ask} asked for and have {max_num} max." - .format(ask=num_machines, max_num=self.max_machines)) - - with self.machine_condition: - while not self._enough_machines(num_machines): - self.machine_condition.wait(timeout=1) - return [self.machines.pop(0) for _ in range(num_machines)] - - def release_machines(self, machine_list: List[machine.Machine]): - with self.machine_condition: - while machine_list: - next_machine = machine_list.pop() - self.machines.append(next_machine) - self.machine_condition.notify() - - def _enough_machines(self, ask: int): - return ask <= len(self.machines) - - -def build_machines(path: str, num_machines: str = -1) -> List[machine.Machine]: - """Builds machine objects defined by the yaml file "path". - - Args: - path: The path to a yaml file which defines machines. - num_machines: Optional limit on how many machine objects to build. - - Returns: - Machine objects in a list. - - If num_machines is set, len(machines) <= num_machines. - """ - data = parse_yaml(path) - machines = [] - for key, value in data.items(): - if len(machines) == num_machines: - return machines - if isinstance(value, dict): - machines.append(machine.RemoteMachine(key, **value)) - else: - machines.append(machine.LocalMachine(key)) - return machines - - -def parse_yaml(path: str) -> Dict[str, Dict[str, str]]: - """Parse the yaml file pointed by path. - - Args: - path: The path to yaml file. - - Returns: - The contents of the yaml file as a dictionary. - """ - data = get_file_contents(path) - return yaml.load(data, Loader=yaml.Loader) - - -def get_file_contents(path: str) -> str: - """Dumps the file contents to a string and returns them. - - Args: - path: The path to dump. - - Returns: - The file contents as a string. - """ - if not os.path.isabs(path): - path = os.path.abspath(path) - with open(path) as input_file: - return input_file.read() diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py deleted file mode 100644 index b8c8e42d4..000000000 --- a/benchmarks/harness/ssh_connection.py +++ /dev/null @@ -1,126 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SSHConnection handles the details of SSH connections.""" - -import logging -import os -import warnings - -import paramiko - -from benchmarks import harness - -# Get rid of paramiko Cryptography Warnings. -warnings.filterwarnings(action="ignore", module=".*paramiko.*") - -log = logging.getLogger(__name__) - - -def send_one_file(client: paramiko.SSHClient, path: str, - remote_dir: str) -> str: - """Sends a single file via an SSH client. - - Args: - client: The existing SSH client. - path: The local path. - remote_dir: The remote directory. - - Returns: - :return: The remote path as a string. - """ - filename = path.split("/").pop() - if remote_dir != ".": - client.exec_command("mkdir -p " + remote_dir) - with client.open_sftp() as ftp_client: - ftp_client.put(path, os.path.join(remote_dir, filename)) - return os.path.join(remote_dir, filename) - - -class SSHConnection: - """SSH connection to a remote machine.""" - - def __init__(self, name: str, hostname: str, key_path: str, username: str, - **kwargs): - """Sets up a paramiko ssh connection to the given hostname.""" - self._name = name # Unused. - self._hostname = hostname - self._username = username - self._key_path = key_path # RSA Key path - self._kwargs = kwargs - # SSHConnection wraps paramiko. paramiko supports RSA, ECDSA, and Ed25519 - # keys, and we've chosen to only suport and require RSA keys. paramiko - # supports RSA keys that begin with '----BEGIN RSAKEY----'. - # https://stackoverflow.com/questions/53600581/ssh-key-generated-by-ssh-keygen-is-not-recognized-by-paramiko - self.rsa_key = self._rsa() - self.run("true") # Validate. - - def _client(self) -> paramiko.SSHClient: - """Returns a connected SSH client.""" - client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - hostname=self._hostname, - port=22, - username=self._username, - pkey=self.rsa_key, - allow_agent=False, - look_for_keys=False) - return client - - def _rsa(self): - if "key_password" in self._kwargs: - password = self._kwargs["key_password"] - else: - password = None - rsa = paramiko.RSAKey.from_private_key_file(self._key_path, password) - return rsa - - def run(self, cmd: str) -> (str, str): - """Runs a command via ssh. - - Args: - cmd: The shell command to run. - - Returns: - The contents of stdout and stderr. - """ - with self._client() as client: - log.info("running command: %s", cmd) - _, stdout, stderr = client.exec_command(command=cmd) - log.info("returned status: %d", stdout.channel.recv_exit_status()) - stdout = stdout.read().decode("utf-8") - stderr = stderr.read().decode("utf-8") - log.info("stdout: %s", stdout) - log.info("stderr: %s", stderr) - return stdout, stderr - - def send_workload(self, name: str) -> str: - """Sends a workload tarball to the remote machine. - - Args: - name: The workload name. - - Returns: - The remote path. - """ - with self._client() as client: - return send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name), - harness.REMOTE_WORKLOADS_PATH.format(name)) - - def send_installers(self) -> str: - with self._client() as client: - return send_one_file( - client, - path=harness.INSTALLER_ARCHIVE, - remote_dir=harness.REMOTE_INSTALLERS_PATH) diff --git a/benchmarks/harness/tunnel_dispatcher.py b/benchmarks/harness/tunnel_dispatcher.py deleted file mode 100644 index c56fd022a..000000000 --- a/benchmarks/harness/tunnel_dispatcher.py +++ /dev/null @@ -1,122 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tunnel handles setting up connections to remote machines. - -Tunnel dispatcher is a wrapper around the connection from a local UNIX socket -and a remote UNIX socket via SSH with port forwarding. This is done to -initialize the pythonic dockerpy client to run containers on the remote host by -connecting to /var/run/docker.sock (where Docker is listening). Tunnel -dispatcher sets up the local UNIX socket and calls the `ssh` command as a -subprocess, and holds a reference to that subprocess. It manages clean-up on -exit as best it can by killing the ssh subprocess and deleting the local UNIX -socket,stored in /tmp for easy cleanup in most systems if this fails. - - Typical usage example: - - t = Tunnel(name, **kwargs) - t.connect() - client = t.get_docker_client() # - client.containers.run("ubuntu", "echo hello world") - -""" - -import os -import tempfile -import time - -import docker -import pexpect - -SSH_TUNNEL_COMMAND = """ssh - -o GlobalKnownHostsFile=/dev/null - -o UserKnownHostsFile=/dev/null - -o StrictHostKeyChecking=no - -o IdentitiesOnly=yes - -nNT -L {filename}:/var/run/docker.sock - -i {key_path} - {username}@{hostname}""" - - -class Tunnel(object): - """The tunnel object represents the tunnel via ssh. - - This connects a local unix domain socket with a remote socket. - - Attributes: - _filename: a temporary name of the UNIX socket prefixed by the name - argument. - _hostname: the IP or resolvable hostname of the remote host. - _username: the username of the ssh_key used to run ssh. - _key_path: path to a valid key. - _key_password: optional password to the ssh key in _key_path - _process: holds reference to the ssh subprocess created. - - Returns: - The new minimum port. - - Raises: - ConnectionError: If no available port is found. - """ - - def __init__(self, - name: str, - hostname: str, - username: str, - key_path: str, - key_password: str = "", - **kwargs): - self._filename = tempfile.NamedTemporaryFile(prefix=name).name - self._hostname = hostname - self._username = username - self._key_path = key_path - self._key_password = key_password - self._kwargs = kwargs - self._process = None - - def connect(self): - """Connects the SSH tunnel and stores the subprocess reference in _process.""" - cmd = SSH_TUNNEL_COMMAND.format( - filename=self._filename, - key_path=self._key_path, - username=self._username, - hostname=self._hostname) - self._process = pexpect.spawn(cmd, timeout=10) - - # If given a password, assume we'll be asked for it. - if self._key_password: - self._process.expect(["Enter passphrase for key .*: "]) - self._process.sendline(self._key_password) - - while True: - # Wait for the tunnel to appear. - if self._process.exitstatus is not None: - raise ConnectionError("Error in setting up ssh tunnel") - if os.path.exists(self._filename): - return - time.sleep(0.1) - - def path(self): - """Return the socket file.""" - return self._filename - - def get_docker_client(self): - """Returns a docker client for this Tunnel.""" - return docker.DockerClient(base_url="unix:/" + self._filename) - - def __del__(self): - """Closes the ssh connection process and deletes the socket file.""" - if self._process: - self._process.close() - if os.path.exists(self._filename): - os.remove(self._filename) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt deleted file mode 100644 index 577eb1a2e..000000000 --- a/benchmarks/requirements.txt +++ /dev/null @@ -1,32 +0,0 @@ -asn1crypto==1.2.0 -atomicwrites==1.3.0 -attrs==19.3.0 -bcrypt==3.1.7 -certifi==2019.9.11 -cffi==1.13.2 -chardet==3.0.4 -Click==7.0 -cryptography==2.8 -docker==3.7.0 -docker-pycreds==0.4.0 -idna==2.8 -importlib-metadata==0.23 -more-itertools==7.2.0 -packaging==19.2 -paramiko==2.6.0 -pathlib2==2.3.5 -pexpect==4.7.0 -pluggy==0.9.0 -ptyprocess==0.6.0 -py==1.8.0 -pycparser==2.19 -PyNaCl==1.3.0 -pyparsing==2.4.5 -pytest==4.3.0 -PyYAML==5.1.2 -requests==2.22.0 -six==1.13.0 -urllib3==1.25.7 -wcwidth==0.1.7 -websocket-client==0.56.0 -zipp==0.6.0 diff --git a/benchmarks/run.py b/benchmarks/run.py deleted file mode 100644 index a22eb8641..000000000 --- a/benchmarks/run.py +++ /dev/null @@ -1,19 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Benchmark runner.""" - -from benchmarks import runner - -if __name__ == "__main__": - runner.runner() diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD deleted file mode 100644 index 471debfdf..000000000 --- a/benchmarks/runner/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "py_library", "py_requirement", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package(licenses = ["notice"]) - -py_library( - name = "runner", - srcs = ["__init__.py"], - data = [ - "//benchmarks/workloads:files", - ], - visibility = ["//benchmarks:__pkg__"], - deps = [ - ":commands", - "//benchmarks/harness:benchmark_driver", - "//benchmarks/harness/machine_producers:machine_producer", - "//benchmarks/harness/machine_producers:mock_producer", - "//benchmarks/harness/machine_producers:yaml_producer", - "//benchmarks/suites", - "//benchmarks/suites:absl", - "//benchmarks/suites:density", - "//benchmarks/suites:fio", - "//benchmarks/suites:helpers", - "//benchmarks/suites:http", - "//benchmarks/suites:media", - "//benchmarks/suites:ml", - "//benchmarks/suites:network", - "//benchmarks/suites:redis", - "//benchmarks/suites:startup", - "//benchmarks/suites:sysbench", - "//benchmarks/suites:syscall", - py_requirement("click"), - ], -) - -py_library( - name = "commands", - srcs = ["commands.py"], - deps = [ - py_requirement("click"), - ], -) - -py_test( - name = "runner_test", - srcs = ["runner_test.py"], - python_version = "PY3", - tags = [ - "local", - "manual", - ], - deps = test_deps + [ - ":runner", - py_requirement("click"), - ], -) diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py deleted file mode 100644 index fc59cf505..000000000 --- a/benchmarks/runner/__init__.py +++ /dev/null @@ -1,308 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""High-level benchmark utility.""" - -import copy -import csv -import logging -import pkgutil -import pydoc -import re -import subprocess -import sys -import types -from typing import List -from typing import Tuple - -import click - -from benchmarks import harness -from benchmarks import suites -from benchmarks.harness import benchmark_driver -from benchmarks.harness.machine_producers import gcloud_producer -from benchmarks.harness.machine_producers import machine_producer -from benchmarks.harness.machine_producers import mock_producer -from benchmarks.harness.machine_producers import yaml_producer -from benchmarks.runner import commands - - -@click.group() -@click.option( - "--verbose/--no-verbose", default=False, help="Enable verbose logging.") -@click.option("--debug/--no-debug", default=False, help="Enable debug logging.") -def runner(verbose: bool = False, debug: bool = False): - """Run distributed benchmarks. - - See the run and list commands for details. - - Args: - verbose: Enable verbose logging. - debug: Enable debug logging (supercedes verbose). - """ - if debug: - logging.basicConfig(level=logging.DEBUG) - elif verbose: - logging.basicConfig(level=logging.INFO) - - -def find_benchmarks( - regex: str) -> List[Tuple[str, types.ModuleType, types.FunctionType]]: - """Finds all available benchmarks. - - Args: - regex: A regular expression to match. - - Returns: - A (short_name, module, function) tuple for each match. - """ - pkgs = pkgutil.walk_packages(suites.__path__, suites.__name__ + ".") - found = [] - for _, name, _ in pkgs: - mod = pydoc.locate(name) - funcs = [ - getattr(mod, x) - for x in dir(mod) - if suites.is_benchmark(getattr(mod, x)) - ] - for func in funcs: - # Use the short_name with the benchmarks. prefix stripped. - prefix_len = len(suites.__name__ + ".") - short_name = mod.__name__[prefix_len:] + "." + func.__name__ - # Add to the list if a pattern is provided. - if re.compile(regex).match(short_name): - found.append((short_name, mod, func)) - return found - - -@runner.command("list") -@click.argument("method", nargs=-1) -def list_all(method): - """Lists available benchmarks.""" - if not method: - method = ".*" - else: - method = "(" + ",".join(method) + ")" - for (short_name, _, func) in find_benchmarks(method): - print("Benchmark %s:" % short_name) - metrics = suites.benchmark_metrics(func) - if func.__doc__: - print(" " + func.__doc__.lstrip().rstrip()) - if metrics: - print("\n Metrics:") - for metric in metrics: - print("\t{name}: {doc}".format(name=metric[0], doc=metric[1])) - print("\n") - - -@runner.command("run-local", commands.LocalCommand) -@click.pass_context -def run_local(ctx, limit: float, **kwargs): - """Runs benchmarks locally.""" - run(ctx, machine_producer.LocalMachineProducer(limit=limit), **kwargs) - - -@runner.command("run-mock", commands.RunCommand) -@click.pass_context -def run_mock(ctx, **kwargs): - """Runs benchmarks on Mock machines. Used for testing.""" - run(ctx, mock_producer.MockMachineProducer(), **kwargs) - - -@runner.command("run-gcp", commands.GCPCommand) -@click.pass_context -def run_gcp(ctx, image_file: str, zone_file: str, internal: bool, - machine_type: str, installers: List[str], **kwargs): - """Runs all benchmarks on GCP instances.""" - - # Resolve all files. - image = subprocess.check_output([image_file]).rstrip() - zone = subprocess.check_output([zone_file]).rstrip() - key_file = harness.make_key() - - producer = gcloud_producer.GCloudProducer( - image, - zone, - machine_type, - installers, - ssh_key_file=key_file, - ssh_user=harness.DEFAULT_USER, - ssh_password="", - internal=internal) - - try: - run(ctx, producer, **kwargs) - finally: - harness.delete_key() - - -def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int, - runtime: List[str], metric: List[str], stat: str, **kwargs): - """Runs arbitrary benchmarks. - - All unknown command line flags are passed through to the underlying benchmark - method. Flags may be specified multiple times, in which case it is considered - a "dimension" for the test, and a comma-separated table will be emitted - instead of a single result. - - See the output of list to see available metrics for any given benchmark - method. The method parameter is a regular expression that will match against - available benchmarks. If multiple benchmarks match, then that is considered a - distinct "dimension" for the test. - - All benchmarks are run in parallel where possible, but have exclusive - ownership over the individual machines. - - Every benchmark method will be run the times indicated by --runs. - - Args: - ctx: Click context. - producer: A Machine Producer from which to get Machines. - method: A regular expression for methods to be run. - runs: Number of runs. - runtime: A list of runtimes to test. - metric: A list of metrics to extract. - stat: The class of statistics to extract. - **kwargs: Dimensions to test. - """ - # First, calculate additional arguments. - # - # This essentially calculates any arguments that appear multiple times, and - # moves those to the "dimensions" dictionary, which maps to lists. These - # dimensions are then iterated over to generate the relevant csv output. - dimensions = {} - - if stat not in ["median", "all", "meanstd"]: - raise ValueError("Illegal value for --result, see help.") - - def squish(key: str, value: str): - """Collapse an argument into kwargs or dimensions.""" - if key in dimensions: - # Extend an existing dimension. - dimensions[key].append(value) - elif key in kwargs: - # Create a new dimension. - dimensions[key] = [kwargs[key], value] - del kwargs[key] - else: - # A single value. - kwargs[key] = value - - for item in ctx.args: - if "=" in method: - # This must be the method. The method is simply set to the first - # non-matching argument, which we're also parsing here. - item, method = method, item - if "=" not in item: - logging.error("illegal argument: %s", item) - sys.exit(1) - (key, value) = item.lstrip("-").split("=", 1) - squish(key, value) - - # Convert runtime and metric to dimensions. - # - # They exist only in the arguments above for documentation purposes. - # Essentially here we are treating them like anything else. Note however, - # that an empty set here will result in a dimension. This is important for - # metrics, where an empty set actually means all metrics. - def fold(key: str, value, allow_flatten=False): - """Collapse a list value into kwargs or dimensions.""" - if len(value) == 1 and allow_flatten: - kwargs[key] = value[0] - else: - dimensions[key] = value - - fold("runtime", runtime, allow_flatten=True) - fold("metric", metric) - - # Lookup the methods. - # - # We match the method parameter to a regular expression. This allows you to - # do things like `run --mock .*` for a broad test. Note that we track the - # short_names in the dimensions here, and look up again in the recursion. - methods = { - short_name: func for (short_name, _, func) in find_benchmarks(method) - } - if not methods: - # Must match at least one method. - logging.error("no matching benchmarks for %s: try list.", method) - sys.exit(1) - fold("method", list(methods.keys()), allow_flatten=True) - - # Spin up the drivers. - # - # We ensure that metric is the last entry, because we have special behavior. - # They actually run the test once and the benchmark is a generator that - # produces all viable metrics. - dimension_keys = list(dimensions.keys()) - if "metric" in dimension_keys: - dimension_keys.remove("metric") - dimension_keys.append("metric") - drivers = [] - - def _start(keywords, finished, left): - """Runs a test across dimensions recursively.""" - # Resolve the method fully, it starts as a string. - if "method" in keywords and isinstance(keywords["method"], str): - keywords["method"] = methods[keywords["method"]] - # Is this a non-recursive case? - if not left: - driver = benchmark_driver.BenchmarkDriver(producer, runs=runs, **keywords) - driver.start() - drivers.append((finished, driver)) - else: - # Recurse on the next dimension. - current, left = left[0], left[1:] - keywords = copy.deepcopy(keywords) - if current == "metric": - # We use a generator, popped below. Note that metric is - # guaranteed to be the last element here, and we will provide - # the value for 'done' below when generating the csv. - keywords[current] = dimensions[current] - _start(keywords, finished, left) - else: - # Generate manually. - for value in dimensions[current]: - keywords[current] = value - _start(keywords, finished + [value], left) - - # Start all the drivers, recursively. - _start(kwargs, [], dimension_keys) - - # Finish all tests, write results. - output = csv.writer(sys.stdout) - output.writerow(dimension_keys + ["result"]) - for (done, driver) in drivers: - driver.join() - for (metric_name, result) in getattr(driver, stat)(): - output.writerow([ # Collapse the method name. - hasattr(x, "__name__") and x.__name__ or x for x in done - ] + [metric_name] + result) - - -@runner.command() -@click.argument("env") -@click.option( - "--cmd", default="uname -a", help="command to run on all found machines") -@click.option( - "--workload", default="true", help="workload to run all found machines") -def validate(env, cmd, workload): - """Validates an environment described by yaml file.""" - producer = yaml_producer.YamlMachineProducer(env) - for machine in producer.machines: - print("Machine %s:" % machine) - stdout, _ = machine.run(cmd) - print(" Output of '%s': %s" % (cmd, stdout.lstrip().rstrip())) - image = machine.pull(workload) - stdout = machine.container(image).run() - print(" Container %s: %s" % (workload, stdout.lstrip().rstrip())) diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py deleted file mode 100644 index 9a391eb01..000000000 --- a/benchmarks/runner/commands.py +++ /dev/null @@ -1,135 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Module with the guts of `click` commands. - -Overrides of the click.core.Command. This is done so flags are inherited between -similar commands (the run command). The classes below are meant to be used in -click templates like so. - -@runner.command("run-mock", RunCommand) -def run_mock(**kwargs): - # mock implementation - -""" -import os - -import click - - -class RunCommand(click.core.Command): - """Base Run Command with flags. - - Attributes: - method: regex of which suite to choose (e.g. sysbench would run - sysbench.cpu, sysbench.memory, and sysbench.mutex) See list command for - details. - metric: metric(s) to extract. See list command for details. - runtime: the runtime(s) on which to run. - runs: the number of runs to do of each method. - stat: how to compile results in the case of multiple run (e.g. median). - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - method = click.core.Argument(("method",)) - - metric = click.core.Option(("--metric",), - help="The metric to extract.", - multiple=True) - - runtime = click.core.Option(("--runtime",), - default=["runc"], - help="The runtime to use.", - multiple=True) - runs = click.core.Option(("--runs",), - default=1, - help="The number of times to run each benchmark.") - stat = click.core.Option( - ("--stat",), - default="median", - help="How to aggregate the data from all runs." - "\nmedian - returns the median of all runs (default)" - "\nall - returns all results comma separated" - "\nmeanstd - returns result as mean,std") - self.params.extend([method, runtime, runs, stat, metric]) - self.ignore_unknown_options = True - self.allow_extra_args = True - - -class LocalCommand(RunCommand): - """LocalCommand inherits all flags from RunCommand. - - Attributes: - limit: limits the number of machines on which to run benchmarks. This limits - for local how many benchmarks may run at a time. e.g. "startup" requires - one machine -- passing two machines would limit two startup jobs at a - time. Default is infinity. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.params.append( - click.core.Option( - ("--limit",), - default=1, - help="Limit of number of benchmarks that can run at a given time.")) - - -class GCPCommand(RunCommand): - """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method. - - Attributes: - image_file: name of the image to build machines from - zone_file: a GCP zone (e.g. us-west1-b) - installers: named installers for post-create - machine_type: type of machine to create (e.g. n1-standard-4) - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - image_file = click.core.Option( - ("--image_file",), - help="The binary that emits the GCP image.", - default=os.path.join( - os.path.dirname(__file__), "../../tools/vm/ubuntu1604"), - ) - zone_file = click.core.Option( - ("--zone_file",), - help="The binary that emits the GCP zone.", - default=os.path.join(os.path.dirname(__file__), "../../tools/vm/zone"), - ) - internal = click.core.Option( - ("--internal/--no-internal",), - help="""Use instance internal IPs. Used if bm-tools runner is running on - GCP instance with firewall rules blocking external IPs.""", - default=False, - ) - installers = click.core.Option( - ("--installers",), - help="The set of installers to use.", - multiple=True, - ) - machine_type = click.core.Option( - ("--machine_type",), - help="Type to make all machines.", - default="n1-standard-4", - ) - self.params.extend([ - image_file, - zone_file, - internal, - machine_type, - installers, - ]) diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py deleted file mode 100644 index 7818d631a..000000000 --- a/benchmarks/runner/runner_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Top-level tests.""" - -import os -import subprocess -import sys - -from click import testing -import pytest - -from benchmarks import runner - - -def _get_locale(): - output = subprocess.check_output(["locale", "-a"]) - locales = output.split() - if b"en_US.utf8" in locales: - return "en_US.UTF-8" - else: - return "C.UTF-8" - - -def _set_locale(): - locale = _get_locale() - if os.getenv("LANG") != locale: - os.environ["LANG"] = locale - os.environ["LC_ALL"] = locale - os.execv("/proc/self/exe", ["python"] + sys.argv) - - -def test_list(): - cli_runner = testing.CliRunner() - result = cli_runner.invoke(runner.runner, ["list"]) - print(result.output) - assert result.exit_code == 0 - - -def test_run(): - cli_runner = testing.CliRunner() - result = cli_runner.invoke(runner.runner, ["run-mock", "."]) - print(result.output) - assert result.exit_code == 0 - - -if __name__ == "__main__": - _set_locale() - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/suites/BUILD b/benchmarks/suites/BUILD deleted file mode 100644 index 04fc23261..000000000 --- a/benchmarks/suites/BUILD +++ /dev/null @@ -1,130 +0,0 @@ -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "suites", - srcs = ["__init__.py"], -) - -py_library( - name = "absl", - srcs = ["absl.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/absl", - ], -) - -py_library( - name = "density", - srcs = ["density.py"], - deps = [ - "//benchmarks/harness:container", - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - ], -) - -py_library( - name = "fio", - srcs = ["fio.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - "//benchmarks/workloads/fio", - ], -) - -py_library( - name = "helpers", - srcs = ["helpers.py"], - deps = ["//benchmarks/harness:machine"], -) - -py_library( - name = "http", - srcs = ["http.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/ab", - ], -) - -py_library( - name = "media", - srcs = ["media.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - "//benchmarks/workloads/ffmpeg", - ], -) - -py_library( - name = "ml", - srcs = ["ml.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:startup", - "//benchmarks/workloads/tensorflow", - ], -) - -py_library( - name = "network", - srcs = ["network.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - "//benchmarks/workloads/iperf", - ], -) - -py_library( - name = "redis", - srcs = ["redis.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/redisbenchmark", - ], -) - -py_library( - name = "startup", - srcs = ["startup.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/suites:helpers", - ], -) - -py_library( - name = "sysbench", - srcs = ["sysbench.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/sysbench", - ], -) - -py_library( - name = "syscall", - srcs = ["syscall.py"], - deps = [ - "//benchmarks/harness:machine", - "//benchmarks/suites", - "//benchmarks/workloads/syscall", - ], -) diff --git a/benchmarks/suites/__init__.py b/benchmarks/suites/__init__.py deleted file mode 100644 index 360736cc3..000000000 --- a/benchmarks/suites/__init__.py +++ /dev/null @@ -1,119 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Core benchmark annotations.""" - -import functools -import inspect -import types -from typing import List -from typing import Tuple - -BENCHMARK_METRICS = '__benchmark_metrics__' -BENCHMARK_MACHINES = '__benchmark_machines__' - - -def is_benchmark(func: types.FunctionType) -> bool: - """Returns true if the given function is a benchmark.""" - return isinstance(func, types.FunctionType) and \ - hasattr(func, BENCHMARK_METRICS) and \ - hasattr(func, BENCHMARK_MACHINES) - - -def benchmark_metrics(func: types.FunctionType) -> List[Tuple[str, str]]: - """Returns the list of available metrics.""" - return [(metric.__name__, metric.__doc__) - for metric in getattr(func, BENCHMARK_METRICS)] - - -def benchmark_machines(func: types.FunctionType) -> int: - """Returns the number of machines required.""" - return getattr(func, BENCHMARK_MACHINES) - - -# pylint: disable=unused-argument -def default(value, **kwargs): - """Returns the passed value.""" - return value - - -def benchmark(metrics: List[types.FunctionType] = None, - machines: int = 1) -> types.FunctionType: - """Define a benchmark function with metrics. - - Args: - metrics: A list of metric functions. - machines: The number of machines required. - - Returns: - A function that accepts the given number of machines, and iteratively - returns a set of (metric_name, metric_value) pairs when called repeatedly. - """ - if not metrics: - # The default passes through. - metrics = [default] - - def decorator(func: types.FunctionType) -> types.FunctionType: - """Decorator function.""" - # Every benchmark should accept at least two parameters: - # runtime: The runtime to use for the benchmark (str, required). - # metrics: The metrics to use, if not the default (str, optional). - @functools.wraps(func) - def wrapper(*args, runtime: str, metric: list = None, **kwargs): - """Wrapper function.""" - # First -- ensure that we marshall all types appropriately. In - # general, we will call this with only strings. These strings will - # need to be converted to their underlying types/classes. - sig = inspect.signature(func) - for param in sig.parameters.values(): - if param.annotation != inspect.Parameter.empty and \ - param.name in kwargs and not isinstance(kwargs[param.name], param.annotation): - try: - # Marshall to the appropriate type. - kwargs[param.name] = param.annotation(kwargs[param.name]) - except Exception as exc: - raise ValueError( - 'illegal type for %s(%s=%s): %s' % - (func.__name__, param.name, kwargs[param.name], exc)) - elif param.default != inspect.Parameter.empty and \ - param.name not in kwargs: - # Ensure that we have the value set, because it will - # be passed to the metric function for evaluation. - kwargs[param.name] = param.default - - # Next, figure out how to apply a metric. We do this prior to - # running the underlying function to prevent having to wait a few - # minutes for a result just to see some error. - if not metric: - # Return all metrics in the iterator. - result = func(*args, runtime=runtime, **kwargs) - for metric_func in metrics: - yield (metric_func.__name__, metric_func(result, **kwargs)) - else: - result = None - for single_metric in metric: - for metric_func in metrics: - # Is this a function that matches the name? - # Apply this function to the result. - if metric_func.__name__ == single_metric: - if not result: - # Lazy evaluation: only if metric matches. - result = func(*args, runtime=runtime, **kwargs) - yield single_metric, metric_func(result, **kwargs) - - # Set metadata on the benchmark (used above). - setattr(wrapper, BENCHMARK_METRICS, metrics) - setattr(wrapper, BENCHMARK_MACHINES, machines) - return wrapper - - return decorator diff --git a/benchmarks/suites/absl.py b/benchmarks/suites/absl.py deleted file mode 100644 index 5d9b57a09..000000000 --- a/benchmarks/suites/absl.py +++ /dev/null @@ -1,37 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""absl build benchmark.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import absl - - -@suites.benchmark(metrics=[absl.elapsed_time], machines=1) -def build(target: machine.Machine, **kwargs) -> str: - """Runs the absl workload and report the absl build time. - - Runs the 'bazel build //absl/...' in a clean bazel directory and - monitors time elapsed. - - Args: - target: A machine object. - **kwargs: Additional container options. - - Returns: - Container output. - """ - image = target.pull("absl") - return target.container(image, **kwargs).run() diff --git a/benchmarks/suites/density.py b/benchmarks/suites/density.py deleted file mode 100644 index 89d29fb26..000000000 --- a/benchmarks/suites/density.py +++ /dev/null @@ -1,121 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Density tests.""" - -import re -import types - -from benchmarks import suites -from benchmarks.harness import container -from benchmarks.harness import machine -from benchmarks.suites import helpers - - -# pylint: disable=unused-argument -def memory_usage(value, **kwargs): - """Returns the passed value.""" - return value - - -def density(target: machine.Machine, - workload: str, - count: int = 50, - wait: float = 0, - load_func: types.FunctionType = None, - **kwargs): - """Calculate the average memory usage per container. - - Args: - target: A machine object. - workload: The workload to run. - count: The number of containers to start. - wait: The time to wait after starting. - load_func: Callback that is called after count images have been started on - the given machine. - **kwargs: Additional container options. - - Returns: - The average usage in Kb per container. - """ - count = int(count) - - # Drop all caches. - helpers.drop_caches(target) - before = target.read("/proc/meminfo") - - # Load the workload. - image = target.pull(workload) - - with target.container( - image=image, count=count, **kwargs).detach() as containers: - # Call the optional load function callback if given. - if load_func: - load_func(target, containers) - # Wait 'wait' time before taking a measurement. - target.sleep(wait) - - # Drop caches again. - helpers.drop_caches(target) - after = target.read("/proc/meminfo") - - # Calculate the memory used. - available_re = re.compile(r"MemAvailable:\s*(\d+)\skB\n") - before_available = available_re.findall(before) - after_available = available_re.findall(after) - return 1024 * float(int(before_available[0]) - - int(after_available[0])) / float(count) - - -def load_redis(target: machine.Machine, containers: container.Container): - """Use redis-benchmark "LPUSH" to load each container with 1G of data. - - Args: - target: A machine object. - containers: A set of containers. - """ - target.pull("redisbenchmark") - for name in containers.get_names(): - flags = "-d 10000 -t LPUSH" - target.container( - "redisbenchmark", links={ - name: name - }).run( - host=name, flags=flags) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def empty(target: machine.Machine, **kwargs) -> float: - """Run trivial containers in a density test.""" - return density(target, workload="sleep", wait=1.0, **kwargs) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def node(target: machine.Machine, **kwargs) -> float: - """Run node containers in a density test.""" - return density(target, workload="node", wait=3.0, **kwargs) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def ruby(target: machine.Machine, **kwargs) -> float: - """Run ruby containers in a density test.""" - return density(target, workload="ruby", wait=3.0, **kwargs) - - -@suites.benchmark(metrics=[memory_usage], machines=1) -def redis(target: machine.Machine, **kwargs) -> float: - """Run redis containers in a density test.""" - if "count" not in kwargs: - kwargs["count"] = 5 - return density( - target, workload="redis", wait=3.0, load_func=load_redis, **kwargs) diff --git a/benchmarks/suites/fio.py b/benchmarks/suites/fio.py deleted file mode 100644 index 2171790c5..000000000 --- a/benchmarks/suites/fio.py +++ /dev/null @@ -1,165 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""File I/O tests.""" - -import os - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers -from benchmarks.workloads import fio - - -# pylint: disable=too-many-arguments -# pylint: disable=too-many-locals -def run_fio(target: machine.Machine, - test: str, - ioengine: str = "sync", - size: int = 1024 * 1024 * 1024, - iodepth: int = 4, - blocksize: int = 1024 * 1024, - time: int = -1, - mount_dir: str = "", - filename: str = "file.dat", - tmpfs: bool = False, - ramp_time: int = 0, - **kwargs) -> str: - """FIO benchmarks. - - For more on fio see: - https://media.readthedocs.org/pdf/fio/latest/fio.pdf - - Args: - target: A machine object. - test: The test to run (read, write, randread, randwrite, etc.) - ioengine: The engine for I/O. - size: The size of the generated file in bytes (if an integer) or 5g, 16k, - etc. - iodepth: The I/O for certain engines. - blocksize: The blocksize for reads and writes in bytes (if an integer) or - 4k, etc. - time: If test is time based, how long to run in seconds. - mount_dir: The absolute path on the host to mount a bind mount. - filename: The name of the file to creat inside container. For a path of - /dir/dir/file, the script setup a volume like 'docker run -v - mount_dir:/dir/dir fio' and fio will create (and delete) the file - /dir/dir/file. If tmpfs is set, this /dir/dir will be a tmpfs. - tmpfs: If true, mount on tmpfs. - ramp_time: The time to run before recording statistics - **kwargs: Additional container options. - - Returns: - The output of fio as a string. - """ - # Pull the image before dropping caches. - image = target.pull("fio") - - if not mount_dir: - stdout, _ = target.run("pwd") - mount_dir = stdout.rstrip() - - # Setup the volumes. - volumes = {mount_dir: {"bind": "/disk", "mode": "rw"}} if not tmpfs else None - tmpfs = {"/disk": ""} if tmpfs else None - - # Construct a file in the volume. - filepath = os.path.join("/disk", filename) - - # If we are running a read test, us fio to write a file and then flush file - # data from memory. - if "read" in test: - target.container( - image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( - test="write", - ioengine="sync", - size=size, - iodepth=iodepth, - blocksize=blocksize, - path=filepath) - helpers.drop_caches(target) - - # Run the test. - time_str = "--time_base --runtime={time}".format( - time=time) if int(time) > 0 else "" - res = target.container( - image, volumes=volumes, tmpfs=tmpfs, **kwargs).run( - test=test, - ioengine=ioengine, - size=size, - iodepth=iodepth, - blocksize=blocksize, - time=time_str, - path=filepath, - ramp_time=ramp_time) - - target.run( - "rm {path}".format(path=os.path.join(mount_dir.rstrip(), filename))) - - return res - - -@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) -def read(*args, **kwargs): - """Read test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="read", **kwargs) - - -@suites.benchmark(metrics=[fio.read_bandwidth, fio.read_io_ops], machines=1) -def randread(*args, **kwargs): - """Random read test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="randread", **kwargs) - - -@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) -def write(*args, **kwargs): - """Write test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="write", **kwargs) - - -@suites.benchmark(metrics=[fio.write_bandwidth, fio.write_io_ops], machines=1) -def randwrite(*args, **kwargs): - """Random write test. - - Args: - *args: None. - **kwargs: Additional container options. - - Returns: - The output of fio. - """ - return run_fio(*args, test="randwrite", **kwargs) diff --git a/benchmarks/suites/helpers.py b/benchmarks/suites/helpers.py deleted file mode 100644 index b3c7360ab..000000000 --- a/benchmarks/suites/helpers.py +++ /dev/null @@ -1,57 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Benchmark helpers.""" - -import datetime -from benchmarks.harness import machine - - -class Timer: - """Helper to time runtime of some call. - - Usage: - - with Timer as t: - # do something. - t.get_time_in_seconds() - """ - - def __init__(self): - self._start = datetime.datetime.now() - - def __enter__(self): - self.start() - return self - - def start(self): - """Starts the timer.""" - self._start = datetime.datetime.now() - - def elapsed(self) -> float: - """Returns the elapsed time in seconds.""" - return (datetime.datetime.now() - self._start).total_seconds() - - def __exit__(self, exception_type, exception_value, exception_traceback): - pass - - -def drop_caches(target: machine.Machine): - """Drops caches on the machine. - - Args: - target: A machine object. - """ - target.run("sudo sync") - target.run("sudo sysctl vm.drop_caches=3") - target.run("sudo sysctl vm.drop_caches=3") diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py deleted file mode 100644 index 6efea938c..000000000 --- a/benchmarks/suites/http.py +++ /dev/null @@ -1,138 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""HTTP benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import ab - - -# pylint: disable=too-many-arguments -def http(server: machine.Machine, - client: machine.Machine, - workload: str, - requests: int = 5000, - connections: int = 10, - port: int = 80, - path: str = "notfound", - **kwargs) -> str: - """Run apachebench (ab) against an http server. - - Args: - server: A machine object. - client: A machine object. - workload: The http-serving workload. - requests: Number of requests to send the server. Default is 5000. - connections: Number of concurent connections to use. Default is 10. - port: The port to access in benchmarking. - path: File to download, generally workload-specific. - **kwargs: Additional container options. - - Returns: - The full apachebench output. - """ - # Pull the client & server. - apachebench = client.pull("ab") - netcat = client.pull("netcat") - image = server.pull(workload) - - with server.container(image, port=port, **kwargs).detach() as container: - (host, port) = container.address() - # Wait for the server to come up. - client.container(netcat).run(host=host, port=port) - # Run the benchmark, no arguments. - return client.container(apachebench).run( - host=host, - port=port, - requests=requests, - connections=connections, - path=path) - - -# pylint: disable=too-many-arguments -# pylint: disable=too-many-locals -def http_app(server: machine.Machine, - client: machine.Machine, - workload: str, - requests: int = 5000, - connections: int = 10, - port: int = 80, - path: str = "notfound", - **kwargs) -> str: - """Run apachebench (ab) against an http application. - - Args: - server: A machine object. - client: A machine object. - workload: The http-serving workload. - requests: Number of requests to send the server. Default is 5000. - connections: Number of concurent connections to use. Default is 10. - port: The port to use for benchmarking. - path: File to download, generally workload-specific. - **kwargs: Additional container options. - - Returns: - The full apachebench output. - """ - # Pull the client & server. - apachebench = client.pull("ab") - netcat = client.pull("netcat") - server_netcat = server.pull("netcat") - redis = server.pull("redis") - image = server.pull(workload) - redis_port = 6379 - redis_name = "{workload}_redis_server".format(workload=workload) - - with server.container(redis, name=redis_name).detach(): - server.container(server_netcat, links={redis_name: redis_name})\ - .run(host=redis_name, port=redis_port) - with server.container(image, port=port, links={redis_name: redis_name}, **kwargs)\ - .detach(host=redis_name) as container: - (host, port) = container.address() - # Wait for the server to come up. - client.container(netcat).run(host=host, port=port) - # Run the benchmark, no arguments. - return client.container(apachebench).run( - host=host, - port=port, - requests=requests, - connections=connections, - path=path) - - -@suites.benchmark(metrics=[ab.transfer_rate, ab.latency], machines=2) -def httpd(*args, **kwargs) -> str: - """Apache2 benchmark.""" - return http(*args, workload="httpd", port=80, **kwargs) - - -@suites.benchmark( - metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) -def nginx(*args, **kwargs) -> str: - """Nginx benchmark.""" - return http(*args, workload="nginx", port=80, **kwargs) - - -@suites.benchmark( - metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) -def node(*args, **kwargs) -> str: - """Node benchmark.""" - return http_app(*args, workload="node_template", path="", port=8080, **kwargs) - - -@suites.benchmark( - metrics=[ab.transfer_rate, ab.latency, ab.requests_per_second], machines=2) -def ruby(*args, **kwargs) -> str: - """Ruby benchmark.""" - return http_app(*args, workload="ruby_template", path="", port=9292, **kwargs) diff --git a/benchmarks/suites/media.py b/benchmarks/suites/media.py deleted file mode 100644 index 9cbffdaa1..000000000 --- a/benchmarks/suites/media.py +++ /dev/null @@ -1,42 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Media processing benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers -from benchmarks.workloads import ffmpeg - - -@suites.benchmark(metrics=[ffmpeg.run_time], machines=1) -def transcode(target: machine.Machine, **kwargs) -> float: - """Runs a video transcoding workload and times it. - - Args: - target: A machine object. - **kwargs: Additional container options. - - Returns: - Total workload runtime. - """ - # Load before timing. - image = target.pull("ffmpeg") - - # Drop caches. - helpers.drop_caches(target) - - # Time startup + transcoding. - with helpers.Timer() as timer: - target.container(image, **kwargs).run() - return timer.elapsed() diff --git a/benchmarks/suites/ml.py b/benchmarks/suites/ml.py deleted file mode 100644 index a394d1f69..000000000 --- a/benchmarks/suites/ml.py +++ /dev/null @@ -1,33 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Machine Learning tests.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import startup -from benchmarks.workloads import tensorflow - - -@suites.benchmark(metrics=[tensorflow.run_time], machines=1) -def train(target: machine.Machine, **kwargs): - """Run the tensorflow benchmark and return the runtime in seconds of workload. - - Args: - target: A machine object. - **kwargs: Additional container options. - - Returns: - The total runtime. - """ - return startup.startup(target, workload="tensorflow", count=1, **kwargs) diff --git a/benchmarks/suites/network.py b/benchmarks/suites/network.py deleted file mode 100644 index f973cf3f1..000000000 --- a/benchmarks/suites/network.py +++ /dev/null @@ -1,101 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Network microbenchmarks.""" - -from typing import Dict - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers -from benchmarks.workloads import iperf - - -def run_iperf(client: machine.Machine, - server: machine.Machine, - client_kwargs: Dict[str, str] = None, - server_kwargs: Dict[str, str] = None) -> str: - """Measure iperf performance. - - Args: - client: A machine object. - server: A machine object. - client_kwargs: Additional client container options. - server_kwargs: Additional server container options. - - Returns: - The output of iperf. - """ - if not client_kwargs: - client_kwargs = dict() - if not server_kwargs: - server_kwargs = dict() - - # Pull images. - netcat = client.pull("netcat") - iperf_client_image = client.pull("iperf") - iperf_server_image = server.pull("iperf") - - # Set this due to a bug in the kernel that resets connections. - client.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") - server.run("sudo /sbin/sysctl -w net.netfilter.nf_conntrack_tcp_be_liberal=1") - - with server.container( - iperf_server_image, port=5001, **server_kwargs).detach() as iperf_server: - (host, port) = iperf_server.address() - # Wait until the service is available. - client.container(netcat).run(host=host, port=port) - # Run a warm-up run. - client.container( - iperf_client_image, stderr=True, **client_kwargs).run( - host=host, port=port) - # Run the client with relevant arguments. - res = client.container(iperf_client_image, stderr=True, **client_kwargs)\ - .run(host=host, port=port) - helpers.drop_caches(client) - return res - - -@suites.benchmark(metrics=[iperf.bandwidth], machines=2) -def upload(client: machine.Machine, server: machine.Machine, **kwargs) -> str: - """Measure upload performance. - - Args: - client: A machine object. - server: A machine object. - **kwargs: Client container options. - - Returns: - The output of iperf. - """ - if kwargs["runtime"] == "runc": - kwargs["network_mode"] = "host" - return run_iperf(client, server, client_kwargs=kwargs) - - -@suites.benchmark(metrics=[iperf.bandwidth], machines=2) -def download(client: machine.Machine, server: machine.Machine, **kwargs) -> str: - """Measure download performance. - - Args: - client: A machine object. - server: A machine object. - **kwargs: Server container options. - - Returns: - The output of iperf. - """ - - client_kwargs = {"network_mode": "host"} - return run_iperf( - client, server, client_kwargs=client_kwargs, server_kwargs=kwargs) diff --git a/benchmarks/suites/redis.py b/benchmarks/suites/redis.py deleted file mode 100644 index b84dd073d..000000000 --- a/benchmarks/suites/redis.py +++ /dev/null @@ -1,46 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Redis benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import redisbenchmark - - -@suites.benchmark(metrics=list(redisbenchmark.METRICS.values()), machines=2) -def redis(server: machine.Machine, - client: machine.Machine, - flags: str = "", - **kwargs) -> str: - """Run redis-benchmark on client pointing at server machine. - - Args: - server: A machine object. - client: A machine object. - flags: Flags to pass redis-benchmark. - **kwargs: Additional container options. - - Returns: - Output from redis-benchmark. - """ - redis_server = server.pull("redis") - redis_client = client.pull("redisbenchmark") - netcat = client.pull("netcat") - with server.container( - redis_server, port=6379, **kwargs).detach() as container: - (host, port) = container.address() - # Wait for the container to be up. - client.container(netcat).run(host=host, port=port) - # Run all redis benchmarks. - return client.container(redis_client).run(host=host, port=port, flags=flags) diff --git a/benchmarks/suites/startup.py b/benchmarks/suites/startup.py deleted file mode 100644 index a1b6c5753..000000000 --- a/benchmarks/suites/startup.py +++ /dev/null @@ -1,110 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Start-up benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.suites import helpers - - -# pylint: disable=unused-argument -def startup_time_ms(value, **kwargs): - """Returns average startup time per container in milliseconds. - - Args: - value: The floating point time in seconds. - **kwargs: Ignored. - - Returns: - The time given in milliseconds. - """ - return value * 1000 - - -def startup(target: machine.Machine, - workload: str, - count: int = 5, - port: int = 0, - **kwargs): - """Time the startup of some workload. - - Args: - target: A machine object. - workload: The workload to run. - count: Number of containers to start. - port: The port to check for liveness, if provided. - **kwargs: Additional container options. - - Returns: - The mean start-up time in seconds. - """ - # Load before timing. - image = target.pull(workload) - netcat = target.pull("netcat") - count = int(count) - port = int(port) - - with helpers.Timer() as timer: - for _ in range(count): - if not port: - # Run the container synchronously. - target.container(image, **kwargs).run() - else: - # Run a detached container until httpd available. - with target.container(image, port=port, **kwargs).detach() as server: - (server_host, server_port) = server.address() - target.container(netcat).run(host=server_host, port=server_port) - return timer.elapsed() / float(count) - - -@suites.benchmark(metrics=[startup_time_ms], machines=1) -def empty(target: machine.Machine, **kwargs) -> float: - """Time the startup of a trivial container. - - Args: - target: A machine object. - **kwargs: Additional startup options. - - Returns: - The time to run the container. - """ - return startup(target, workload="true", **kwargs) - - -@suites.benchmark(metrics=[startup_time_ms], machines=1) -def node(target: machine.Machine, **kwargs) -> float: - """Time the startup of the node container. - - Args: - target: A machine object. - **kwargs: Additional statup options. - - Returns: - The time to run the container. - """ - return startup(target, workload="node", port=8080, **kwargs) - - -@suites.benchmark(metrics=[startup_time_ms], machines=1) -def ruby(target: machine.Machine, **kwargs) -> float: - """Time the startup of the ruby container. - - Args: - target: A machine object. - **kwargs: Additional startup options. - - Returns: - The time to run the container. - """ - return startup(target, workload="ruby", port=3000, **kwargs) diff --git a/benchmarks/suites/sysbench.py b/benchmarks/suites/sysbench.py deleted file mode 100644 index 2a6e2126c..000000000 --- a/benchmarks/suites/sysbench.py +++ /dev/null @@ -1,119 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Sysbench-based benchmarks.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads import sysbench - - -def run_sysbench(target: machine.Machine, - test: str = "cpu", - threads: int = 8, - time: int = 5, - options: str = "", - **kwargs) -> str: - """Run sysbench container with arguments. - - Args: - target: A machine object. - test: Relevant sysbench test to run (e.g. cpu, memory). - threads: The number of threads to use for tests. - time: The time to run tests. - options: Additional sysbench options. - **kwargs: Additional container options. - - Returns: - The output of the command as a string. - """ - image = target.pull("sysbench") - return target.container(image, **kwargs).run( - test=test, threads=threads, time=time, options=options) - - -@suites.benchmark(metrics=[sysbench.cpu_events_per_second], machines=1) -def cpu(target: machine.Machine, max_prime: int = 5000, **kwargs) -> str: - """Run sysbench CPU test. - - Additional arguments can be provided for sysbench. - - Args: - target: A machine object. - max_prime: The maximum prime number to search. - **kwargs: - - threads: The number of threads to use for tests. - - time: The time to run tests. - - options: Additional sysbench options. See sysbench tool: - https://github.com/akopytov/sysbench - - Returns: - Sysbench output. - """ - options = kwargs.pop("options", "") - options += " --cpu-max-prime={}".format(max_prime) - return run_sysbench(target, test="cpu", options=options, **kwargs) - - -@suites.benchmark(metrics=[sysbench.memory_ops_per_second], machines=1) -def memory(target: machine.Machine, **kwargs) -> str: - """Run sysbench memory test. - - Additional arguments can be provided per sysbench. - - Args: - target: A machine object. - **kwargs: - - threads: The number of threads to use for tests. - - time: The time to run tests. - - options: Additional sysbench options. See sysbench tool: - https://github.com/akopytov/sysbench - - Returns: - Sysbench output. - """ - return run_sysbench(target, test="memory", **kwargs) - - -@suites.benchmark( - metrics=[ - sysbench.mutex_time, sysbench.mutex_latency, sysbench.mutex_deviation - ], - machines=1) -def mutex(target: machine.Machine, - locks: int = 4, - count: int = 10000000, - threads: int = 8, - **kwargs) -> str: - """Run sysbench mutex test. - - Additional arguments can be provided per sysbench. - - Args: - target: A machine object. - locks: The number of locks to use. - count: The number of mutexes. - threads: The number of threads to use for tests. - **kwargs: - - time: The time to run tests. - - options: Additional sysbench options. See sysbench tool: - https://github.com/akopytov/sysbench - - Returns: - Sysbench output. - """ - options = kwargs.pop("options", "") - options += " --mutex-loops=1 --mutex-locks={} --mutex-num={}".format( - count, locks) - return run_sysbench( - target, test="mutex", options=options, threads=threads, **kwargs) diff --git a/benchmarks/suites/syscall.py b/benchmarks/suites/syscall.py deleted file mode 100644 index fa7665b00..000000000 --- a/benchmarks/suites/syscall.py +++ /dev/null @@ -1,37 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Syscall microbenchmark.""" - -from benchmarks import suites -from benchmarks.harness import machine -from benchmarks.workloads.syscall import syscall_time_ns - - -@suites.benchmark(metrics=[syscall_time_ns], machines=1) -def syscall(target: machine.Machine, count: int = 1000000, **kwargs) -> str: - """Runs the syscall workload and report the syscall time. - - Runs the syscall 'SYS_gettimeofday(0,0)' 'count' times and monitors time - elapsed based on the runtime's MONOTONIC clock. - - Args: - target: A machine object. - count: The number of syscalls to execute. - **kwargs: Additional container options. - - Returns: - Container output. - """ - image = target.pull("syscall") - return target.container(image, **kwargs).run(count=count) diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD deleted file mode 100644 index ccb86af5b..000000000 --- a/benchmarks/workloads/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "workloads", - srcs = ["__init__.py"], -) - -filegroup( - name = "files", - srcs = [ - "//benchmarks/workloads/ab:tar", - "//benchmarks/workloads/absl:tar", - "//benchmarks/workloads/curl:tar", - "//benchmarks/workloads/ffmpeg:tar", - "//benchmarks/workloads/fio:tar", - "//benchmarks/workloads/httpd:tar", - "//benchmarks/workloads/iperf:tar", - "//benchmarks/workloads/netcat:tar", - "//benchmarks/workloads/nginx:tar", - "//benchmarks/workloads/node:tar", - "//benchmarks/workloads/node_template:tar", - "//benchmarks/workloads/redis:tar", - "//benchmarks/workloads/redisbenchmark:tar", - "//benchmarks/workloads/ruby:tar", - "//benchmarks/workloads/ruby_template:tar", - "//benchmarks/workloads/sleep:tar", - "//benchmarks/workloads/sysbench:tar", - "//benchmarks/workloads/syscall:tar", - "//benchmarks/workloads/tensorflow:tar", - "//benchmarks/workloads/true:tar", - ], -) diff --git a/benchmarks/workloads/__init__.py b/benchmarks/workloads/__init__.py deleted file mode 100644 index e12651e76..000000000 --- a/benchmarks/workloads/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Workloads, parsers and test data.""" diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD deleted file mode 100644 index 945ac7026..000000000 --- a/benchmarks/workloads/ab/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "ab", - srcs = ["__init__.py"], -) - -py_test( - name = "ab_test", - srcs = ["ab_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":ab", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/ab/Dockerfile b/benchmarks/workloads/ab/Dockerfile deleted file mode 100644 index 0d0b6e2eb..000000000 --- a/benchmarks/workloads/ab/Dockerfile +++ /dev/null @@ -1,15 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - apache2-utils \ - && rm -rf /var/lib/apt/lists/* - -# Parameterized workload. -ENV requests 5000 -ENV connections 10 -ENV host localhost -ENV port 8080 -ENV path notfound -CMD ["sh", "-c", "ab -n ${requests} -c ${connections} http://${host}:${port}/${path}"] diff --git a/benchmarks/workloads/ab/__init__.py b/benchmarks/workloads/ab/__init__.py deleted file mode 100644 index eedf8e083..000000000 --- a/benchmarks/workloads/ab/__init__.py +++ /dev/null @@ -1,88 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Apachebench tool.""" - -import re - -SAMPLE_DATA = """This is ApacheBench, Version 2.3 <$Revision: 1826891 $> -Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ -Licensed to The Apache Software Foundation, http://www.apache.org/ - -Benchmarking 10.10.10.10 (be patient).....done - - -Server Software: Apache/2.4.38 -Server Hostname: 10.10.10.10 -Server Port: 80 - -Document Path: /latin10k.txt -Document Length: 210 bytes - -Concurrency Level: 1 -Time taken for tests: 0.180 seconds -Complete requests: 100 -Failed requests: 0 -Non-2xx responses: 100 -Total transferred: 38800 bytes -HTML transferred: 21000 bytes -Requests per second: 556.44 [#/sec] (mean) -Time per request: 1.797 [ms] (mean) -Time per request: 1.797 [ms] (mean, across all concurrent requests) -Transfer rate: 210.84 [Kbytes/sec] received - -Connection Times (ms) - min mean[+/-sd] median max -Connect: 0 0 0.2 0 2 -Processing: 1 2 1.0 1 8 -Waiting: 1 1 1.0 1 7 -Total: 1 2 1.2 1 10 - -Percentage of the requests served within a certain time (ms) - 50% 1 - 66% 2 - 75% 2 - 80% 2 - 90% 2 - 95% 3 - 98% 7 - 99% 10 - 100% 10 (longest request)""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def transfer_rate(data: str, **kwargs) -> float: - """Mean transfer rate in Kbytes/sec.""" - regex = r"Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received" - return float(re.compile(regex).search(data).group(1)) - - -# pylint: disable=unused-argument -def latency(data: str, **kwargs) -> float: - """Mean latency in milliseconds.""" - regex = r"Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s" - res = re.compile(regex).search(data) - return float(res.group(1)) - - -# pylint: disable=unused-argument -def requests_per_second(data: str, **kwargs) -> float: - """Requests per second.""" - regex = r"Requests per second:\s+(\d+\.?\d+?)\s+" - res = re.compile(regex).search(data) - return float(res.group(1)) diff --git a/benchmarks/workloads/ab/ab_test.py b/benchmarks/workloads/ab/ab_test.py deleted file mode 100644 index 4afac2996..000000000 --- a/benchmarks/workloads/ab/ab_test.py +++ /dev/null @@ -1,42 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser test.""" - -import sys - -import pytest - -from benchmarks.workloads import ab - - -def test_transfer_rate_parser(): - """Test transfer rate parser.""" - res = ab.transfer_rate(ab.sample()) - assert res == 210.84 - - -def test_latency_parser(): - """Test latency parser.""" - res = ab.latency(ab.sample()) - assert res == 2 - - -def test_requests_per_second(): - """Test requests per second parser.""" - res = ab.requests_per_second(ab.sample()) - assert res == 556.44 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD deleted file mode 100644 index bb1a308bf..000000000 --- a/benchmarks/workloads/absl/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "absl", - srcs = ["__init__.py"], -) - -py_test( - name = "absl_test", - srcs = ["absl_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":absl", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/absl/__init__.py b/benchmarks/workloads/absl/__init__.py deleted file mode 100644 index b40e3f915..000000000 --- a/benchmarks/workloads/absl/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ABSL build benchmark.""" - -import re - -SAMPLE_BAZEL_OUTPUT = """Extracting Bazel installation... -Starting local Bazel server and connecting to it... -Loading: -Loading: 0 packages loaded -Loading: 0 packages loaded - currently loading: absl/algorithm ... (11 packages) -Analyzing: 241 targets (16 packages loaded, 0 targets configured) -Analyzing: 241 targets (21 packages loaded, 617 targets configured) -Analyzing: 241 targets (27 packages loaded, 687 targets configured) -Analyzing: 241 targets (32 packages loaded, 1105 targets configured) -Analyzing: 241 targets (32 packages loaded, 1294 targets configured) -Analyzing: 241 targets (35 packages loaded, 1575 targets configured) -Analyzing: 241 targets (35 packages loaded, 1575 targets configured) -Analyzing: 241 targets (36 packages loaded, 1603 targets configured) -Analyzing: 241 targets (36 packages loaded, 1603 targets configured) -INFO: Analyzed 241 targets (37 packages loaded, 1864 targets configured). -INFO: Found 241 targets... -[0 / 5] [Prepa] BazelWorkspaceStatusAction stable-status.txt -[16 / 50] [Analy] Compiling absl/base/dynamic_annotations.cc ... (20 actions, 10 running) -[60 / 77] Compiling external/com_google_googletest/googletest/src/gtest.cc; 5s processwrapper-sandbox ... (12 actions, 11 running) -[158 / 174] Compiling absl/container/internal/raw_hash_set_test.cc; 2s processwrapper-sandbox ... (12 actions, 11 running) -[278 / 302] Compiling absl/container/internal/raw_hash_set_test.cc; 6s processwrapper-sandbox ... (12 actions, 11 running) -[384 / 406] Compiling absl/container/internal/raw_hash_set_test.cc; 10s processwrapper-sandbox ... (12 actions, 11 running) -[581 / 604] Compiling absl/container/flat_hash_set_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) -[722 / 745] Compiling absl/container/node_hash_set_test.cc; 9s processwrapper-sandbox ... (12 actions, 11 running) -[846 / 867] Compiling absl/hash/hash_test.cc; 11s processwrapper-sandbox ... (12 actions, 11 running) -INFO: From Compiling absl/debugging/symbolize_test.cc: -/tmp/cclCVipU.s: Assembler messages: -/tmp/cclCVipU.s:1662: Warning: ignoring changed section attributes for .text -[999 / 1,022] Compiling absl/hash/hash_test.cc; 19s processwrapper-sandbox ... (12 actions, 11 running) -[1,082 / 1,084] Compiling absl/container/flat_hash_map_test.cc; 7s processwrapper-sandbox -INFO: Elapsed time: 81.861s, Critical Path: 23.81s -INFO: 515 processes: 515 processwrapper-sandbox. -INFO: Build completed successfully, 1084 total actions -INFO: Build completed successfully, 1084 total actions""" - - -def sample(): - return SAMPLE_BAZEL_OUTPUT - - -# pylint: disable=unused-argument -def elapsed_time(data: str, **kwargs) -> float: - """Returns the elapsed time for running an absl build.""" - return float(re.compile(r"Elapsed time: (\d*.?\d*)s").search(data).group(1)) diff --git a/benchmarks/workloads/absl/absl_test.py b/benchmarks/workloads/absl/absl_test.py deleted file mode 100644 index 41f216999..000000000 --- a/benchmarks/workloads/absl/absl_test.py +++ /dev/null @@ -1,31 +0,0 @@ -# python3 -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ABSL build test.""" - -import sys - -import pytest - -from benchmarks.workloads import absl - - -def test_elapsed_time(): - """Test elapsed_time.""" - res = absl.elapsed_time(absl.sample()) - assert res == 81.861 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/curl/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/curl/Dockerfile b/benchmarks/workloads/curl/Dockerfile deleted file mode 100644 index 336cb088a..000000000 --- a/benchmarks/workloads/curl/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Accept a host and port parameter. -ENV host localhost -ENV port 8080 - -# Spin until we make a successful request. -CMD ["sh", "-c", "while ! curl -v -i http://$host:$port; do true; done"] diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD deleted file mode 100644 index 7c41ba631..000000000 --- a/benchmarks/workloads/ffmpeg/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "ffmpeg", - srcs = ["__init__.py"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/ffmpeg/__init__.py b/benchmarks/workloads/ffmpeg/__init__.py deleted file mode 100644 index 7578a443b..000000000 --- a/benchmarks/workloads/ffmpeg/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Simple ffmpeg workload.""" - - -# pylint: disable=unused-argument -def run_time(value, **kwargs): - """Returns the startup and runtime of the ffmpeg workload in seconds.""" - return value diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD deleted file mode 100644 index 24d909c53..000000000 --- a/benchmarks/workloads/fio/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "fio", - srcs = ["__init__.py"], -) - -py_test( - name = "fio_test", - srcs = ["fio_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":fio", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/fio/Dockerfile b/benchmarks/workloads/fio/Dockerfile deleted file mode 100644 index b3cf864eb..000000000 --- a/benchmarks/workloads/fio/Dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - fio \ - && rm -rf /var/lib/apt/lists/* - -# Parameterized test. -ENV test write -ENV ioengine sync -ENV size 5000000 -ENV iodepth 4 -ENV blocksize "1m" -ENV time "" -ENV path "/disk/file.dat" -ENV ramp_time 0 - -CMD ["sh", "-c", "fio --output-format=json --name=test --ramp_time=${ramp_time} --ioengine=${ioengine} --size=${size} \ ---filename=${path} --iodepth=${iodepth} --bs=${blocksize} --rw=${test} ${time}"] - - - diff --git a/benchmarks/workloads/fio/__init__.py b/benchmarks/workloads/fio/__init__.py deleted file mode 100644 index 52711e956..000000000 --- a/benchmarks/workloads/fio/__init__.py +++ /dev/null @@ -1,369 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""FIO benchmark tool.""" - -import json - -SAMPLE_DATA = """ -{ - "fio version" : "fio-3.1", - "timestamp" : 1554837456, - "timestamp_ms" : 1554837456621, - "time" : "Tue Apr 9 19:17:36 2019", - "jobs" : [ - { - "jobname" : "test", - "groupid" : 0, - "error" : 0, - "eta" : 2147483647, - "elapsed" : 1, - "job options" : { - "name" : "test", - "ioengine" : "sync", - "size" : "1073741824", - "filename" : "/disk/file.dat", - "iodepth" : "4", - "bs" : "4096", - "rw" : "write" - }, - "read" : { - "io_bytes" : 0, - "io_kbytes" : 0, - "bw" : 0, - "iops" : 0.000000, - "runtime" : 0, - "total_ios" : 0, - "short_ios" : 0, - "drop_ios" : 0, - "slat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "clat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000, - "percentile" : { - "1.000000" : 0, - "5.000000" : 0, - "10.000000" : 0, - "20.000000" : 0, - "30.000000" : 0, - "40.000000" : 0, - "50.000000" : 0, - "60.000000" : 0, - "70.000000" : 0, - "80.000000" : 0, - "90.000000" : 0, - "95.000000" : 0, - "99.000000" : 0, - "99.500000" : 0, - "99.900000" : 0, - "99.950000" : 0, - "99.990000" : 0, - "0.00" : 0, - "0.00" : 0, - "0.00" : 0 - } - }, - "lat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "bw_min" : 0, - "bw_max" : 0, - "bw_agg" : 0.000000, - "bw_mean" : 0.000000, - "bw_dev" : 0.000000, - "bw_samples" : 0, - "iops_min" : 0, - "iops_max" : 0, - "iops_mean" : 0.000000, - "iops_stddev" : 0.000000, - "iops_samples" : 0 - }, - "write" : { - "io_bytes" : 1073741824, - "io_kbytes" : 1048576, - "bw" : 1753471, - "iops" : 438367.892977, - "runtime" : 598, - "total_ios" : 262144, - "short_ios" : 0, - "drop_ios" : 0, - "slat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "clat_ns" : { - "min" : 1693, - "max" : 754733, - "mean" : 2076.404373, - "stddev" : 1724.195529, - "percentile" : { - "1.000000" : 1736, - "5.000000" : 1752, - "10.000000" : 1768, - "20.000000" : 1784, - "30.000000" : 1800, - "40.000000" : 1800, - "50.000000" : 1816, - "60.000000" : 1816, - "70.000000" : 1848, - "80.000000" : 1928, - "90.000000" : 2512, - "95.000000" : 2992, - "99.000000" : 6176, - "99.500000" : 6304, - "99.900000" : 11328, - "99.950000" : 15168, - "99.990000" : 17792, - "0.00" : 0, - "0.00" : 0, - "0.00" : 0 - } - }, - "lat_ns" : { - "min" : 1731, - "max" : 754770, - "mean" : 2117.878979, - "stddev" : 1730.290512 - }, - "bw_min" : 1731120, - "bw_max" : 1731120, - "bw_agg" : 98.725328, - "bw_mean" : 1731120.000000, - "bw_dev" : 0.000000, - "bw_samples" : 1, - "iops_min" : 432780, - "iops_max" : 432780, - "iops_mean" : 432780.000000, - "iops_stddev" : 0.000000, - "iops_samples" : 1 - }, - "trim" : { - "io_bytes" : 0, - "io_kbytes" : 0, - "bw" : 0, - "iops" : 0.000000, - "runtime" : 0, - "total_ios" : 0, - "short_ios" : 0, - "drop_ios" : 0, - "slat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "clat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000, - "percentile" : { - "1.000000" : 0, - "5.000000" : 0, - "10.000000" : 0, - "20.000000" : 0, - "30.000000" : 0, - "40.000000" : 0, - "50.000000" : 0, - "60.000000" : 0, - "70.000000" : 0, - "80.000000" : 0, - "90.000000" : 0, - "95.000000" : 0, - "99.000000" : 0, - "99.500000" : 0, - "99.900000" : 0, - "99.950000" : 0, - "99.990000" : 0, - "0.00" : 0, - "0.00" : 0, - "0.00" : 0 - } - }, - "lat_ns" : { - "min" : 0, - "max" : 0, - "mean" : 0.000000, - "stddev" : 0.000000 - }, - "bw_min" : 0, - "bw_max" : 0, - "bw_agg" : 0.000000, - "bw_mean" : 0.000000, - "bw_dev" : 0.000000, - "bw_samples" : 0, - "iops_min" : 0, - "iops_max" : 0, - "iops_mean" : 0.000000, - "iops_stddev" : 0.000000, - "iops_samples" : 0 - }, - "usr_cpu" : 17.922948, - "sys_cpu" : 81.574539, - "ctx" : 3, - "majf" : 0, - "minf" : 10, - "iodepth_level" : { - "1" : 100.000000, - "2" : 0.000000, - "4" : 0.000000, - "8" : 0.000000, - "16" : 0.000000, - "32" : 0.000000, - ">=64" : 0.000000 - }, - "latency_ns" : { - "2" : 0.000000, - "4" : 0.000000, - "10" : 0.000000, - "20" : 0.000000, - "50" : 0.000000, - "100" : 0.000000, - "250" : 0.000000, - "500" : 0.000000, - "750" : 0.000000, - "1000" : 0.000000 - }, - "latency_us" : { - "2" : 82.737350, - "4" : 12.605286, - "10" : 4.543686, - "20" : 0.107956, - "50" : 0.010000, - "100" : 0.000000, - "250" : 0.000000, - "500" : 0.000000, - "750" : 0.000000, - "1000" : 0.010000 - }, - "latency_ms" : { - "2" : 0.000000, - "4" : 0.000000, - "10" : 0.000000, - "20" : 0.000000, - "50" : 0.000000, - "100" : 0.000000, - "250" : 0.000000, - "500" : 0.000000, - "750" : 0.000000, - "1000" : 0.000000, - "2000" : 0.000000, - ">=2000" : 0.000000 - }, - "latency_depth" : 4, - "latency_target" : 0, - "latency_percentile" : 100.000000, - "latency_window" : 0 - } - ], - "disk_util" : [ - { - "name" : "dm-1", - "read_ios" : 0, - "write_ios" : 3, - "read_merges" : 0, - "write_merges" : 0, - "read_ticks" : 0, - "write_ticks" : 0, - "in_queue" : 0, - "util" : 0.000000, - "aggr_read_ios" : 0, - "aggr_write_ios" : 3, - "aggr_read_merges" : 0, - "aggr_write_merge" : 0, - "aggr_read_ticks" : 0, - "aggr_write_ticks" : 0, - "aggr_in_queue" : 0, - "aggr_util" : 0.000000 - }, - { - "name" : "dm-0", - "read_ios" : 0, - "write_ios" : 3, - "read_merges" : 0, - "write_merges" : 0, - "read_ticks" : 0, - "write_ticks" : 0, - "in_queue" : 0, - "util" : 0.000000, - "aggr_read_ios" : 0, - "aggr_write_ios" : 3, - "aggr_read_merges" : 0, - "aggr_write_merge" : 0, - "aggr_read_ticks" : 0, - "aggr_write_ticks" : 2, - "aggr_in_queue" : 0, - "aggr_util" : 0.000000 - }, - { - "name" : "nvme0n1", - "read_ios" : 0, - "write_ios" : 3, - "read_merges" : 0, - "write_merges" : 0, - "read_ticks" : 0, - "write_ticks" : 2, - "in_queue" : 0, - "util" : 0.000000 - } - ] -} -""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def read_bandwidth(data: str, **kwargs) -> int: - """File I/O bandwidth.""" - return json.loads(data)["jobs"][0]["read"]["bw"] * 1024 - - -# pylint: disable=unused-argument -def write_bandwidth(data: str, **kwargs) -> int: - """File I/O bandwidth.""" - return json.loads(data)["jobs"][0]["write"]["bw"] * 1024 - - -# pylint: disable=unused-argument -def read_io_ops(data: str, **kwargs) -> float: - """File I/O operations per second.""" - return float(json.loads(data)["jobs"][0]["read"]["iops"]) - - -# pylint: disable=unused-argument -def write_io_ops(data: str, **kwargs) -> float: - """File I/O operations per second.""" - return float(json.loads(data)["jobs"][0]["write"]["iops"]) - - -# Change function names so we just print "bandwidth" and "io_ops". -read_bandwidth.__name__ = "bandwidth" -write_bandwidth.__name__ = "bandwidth" -read_io_ops.__name__ = "io_ops" -write_io_ops.__name__ = "io_ops" diff --git a/benchmarks/workloads/fio/fio_test.py b/benchmarks/workloads/fio/fio_test.py deleted file mode 100644 index 04a6eeb7e..000000000 --- a/benchmarks/workloads/fio/fio_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser tests.""" - -import sys - -import pytest - -from benchmarks.workloads import fio - - -def test_read_io_ops(): - """Test read ops parser.""" - assert fio.read_io_ops(fio.sample()) == 0.0 - - -def test_write_io_ops(): - """Test write ops parser.""" - assert fio.write_io_ops(fio.sample()) == 438367.892977 - - -def test_read_bandwidth(): - """Test read bandwidth parser.""" - assert fio.read_bandwidth(fio.sample()) == 0.0 - - -def test_write_bandwith(): - """Test write bandwidth parser.""" - assert fio.write_bandwidth(fio.sample()) == 1753471 * 1024 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD deleted file mode 100644 index 83450d190..000000000 --- a/benchmarks/workloads/httpd/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "apache2-tmpdir.conf", - ], -) diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD deleted file mode 100644 index 91b953718..000000000 --- a/benchmarks/workloads/iperf/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "iperf", - srcs = ["__init__.py"], -) - -py_test( - name = "iperf_test", - srcs = ["iperf_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":iperf", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/iperf/Dockerfile b/benchmarks/workloads/iperf/Dockerfile deleted file mode 100644 index 9704c506c..000000000 --- a/benchmarks/workloads/iperf/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - iperf \ - && rm -rf /var/lib/apt/lists/* - -# Accept a host parameter. -ENV host "" -ENV port 5001 - -# Start as client if the host is provided. -CMD ["sh", "-c", "test -z \"${host}\" && iperf -s || iperf -f K --realtime -c ${host} -p ${port}"] diff --git a/benchmarks/workloads/iperf/__init__.py b/benchmarks/workloads/iperf/__init__.py deleted file mode 100644 index 3817a7ade..000000000 --- a/benchmarks/workloads/iperf/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""iperf.""" - -import re - -SAMPLE_DATA = """ ------------------------------------------------------------- -Client connecting to 10.138.15.215, TCP port 32779 -TCP window size: 45.0 KByte (default) ------------------------------------------------------------- -[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 -[ ID] Interval Transfer Bandwidth -[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec - -""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def bandwidth(data: str, **kwargs) -> float: - """Calculate the bandwidth.""" - regex = r"\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec" - res = re.compile(regex).search(data) - return float(res.group(1)) * 1000 diff --git a/benchmarks/workloads/iperf/iperf_test.py b/benchmarks/workloads/iperf/iperf_test.py deleted file mode 100644 index 6959b7e8a..000000000 --- a/benchmarks/workloads/iperf/iperf_test.py +++ /dev/null @@ -1,28 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for iperf.""" - -import sys - -import pytest - -from benchmarks.workloads import iperf - - -def test_bandwidth(): - assert iperf.bandwidth(iperf.sample()) == 45900 * 1000 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/netcat/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/netcat/Dockerfile b/benchmarks/workloads/netcat/Dockerfile deleted file mode 100644 index d8548d89a..000000000 --- a/benchmarks/workloads/netcat/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - netcat \ - && rm -rf /var/lib/apt/lists/* - -# Accept a host and port parameter. -ENV host localhost -ENV port 8080 - -# Spin until we make a successful request. -CMD ["sh", "-c", "while ! nc -zv $host $port; do true; done"] diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/nginx/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/nginx/Dockerfile b/benchmarks/workloads/nginx/Dockerfile deleted file mode 100644 index b64eb52ae..000000000 --- a/benchmarks/workloads/nginx/Dockerfile +++ /dev/null @@ -1 +0,0 @@ -FROM nginx:1.15.10 diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD deleted file mode 100644 index bfcf78cf9..000000000 --- a/benchmarks/workloads/node/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "index.js", - "package.json", - ], -) diff --git a/benchmarks/workloads/node/Dockerfile b/benchmarks/workloads/node/Dockerfile deleted file mode 100644 index 139a38bf5..000000000 --- a/benchmarks/workloads/node/Dockerfile +++ /dev/null @@ -1,2 +0,0 @@ -FROM node:onbuild -CMD ["node", "index.js"] diff --git a/benchmarks/workloads/node/index.js b/benchmarks/workloads/node/index.js deleted file mode 100644 index 584158462..000000000 --- a/benchmarks/workloads/node/index.js +++ /dev/null @@ -1,28 +0,0 @@ -'use strict'; - -var start = new Date().getTime(); - -// Load dependencies to simulate an average nodejs app. -var req_0 = require('async'); -var req_1 = require('bluebird'); -var req_2 = require('firebase'); -var req_3 = require('firebase-admin'); -var req_4 = require('@google-cloud/container'); -var req_5 = require('@google-cloud/logging'); -var req_6 = require('@google-cloud/monitoring'); -var req_7 = require('@google-cloud/spanner'); -var req_8 = require('lodash'); -var req_9 = require('mailgun-js'); -var req_10 = require('request'); -var express = require('express'); -var app = express(); - -var loaded = new Date().getTime() - start; -app.get('/', function(req, res) { - res.send('Hello World!<br>Loaded in ' + loaded + 'ms'); -}); - -console.log('Loaded in ' + loaded + ' ms'); -app.listen(8080, function() { - console.log('Listening on port 8080...'); -}); diff --git a/benchmarks/workloads/node/package.json b/benchmarks/workloads/node/package.json deleted file mode 100644 index c00b9b3cb..000000000 --- a/benchmarks/workloads/node/package.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "name": "node", - "version": "1.0.0", - "main": "index.js", - "dependencies": { - "@google-cloud/container": "^0.3.0", - "@google-cloud/logging": "^4.2.0", - "@google-cloud/monitoring": "^0.6.0", - "@google-cloud/spanner": "^2.2.1", - "async": "^2.6.1", - "bluebird": "^3.5.3", - "express": "^4.16.4", - "firebase": "^5.7.2", - "firebase-admin": "^6.4.0", - "lodash": "^4.17.11", - "mailgun-js": "^0.22.0", - "request": "^2.88.0" - } -} diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD deleted file mode 100644 index e142f082a..000000000 --- a/benchmarks/workloads/node_template/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "index.hbs", - "index.js", - "package.json", - "package-lock.json", - ], -) diff --git a/benchmarks/workloads/node_template/Dockerfile b/benchmarks/workloads/node_template/Dockerfile deleted file mode 100644 index 7eb065d54..000000000 --- a/benchmarks/workloads/node_template/Dockerfile +++ /dev/null @@ -1,5 +0,0 @@ -FROM node:onbuild - -ENV host "127.0.0.1" - -CMD ["sh", "-c", "node index.js ${host}"] diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/redis/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD deleted file mode 100644 index 147cfedd2..000000000 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "redisbenchmark", - srcs = ["__init__.py"], -) - -py_test( - name = "redisbenchmark_test", - srcs = ["redisbenchmark_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":redisbenchmark", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/redisbenchmark/Dockerfile b/benchmarks/workloads/redisbenchmark/Dockerfile deleted file mode 100644 index f94f6442e..000000000 --- a/benchmarks/workloads/redisbenchmark/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -FROM redis:5.0.4 -ENV host localhost -ENV port 6379 -CMD ["sh", "-c", "redis-benchmark --csv -h ${host} -p ${port} ${flags}"] diff --git a/benchmarks/workloads/redisbenchmark/__init__.py b/benchmarks/workloads/redisbenchmark/__init__.py deleted file mode 100644 index 229cef5fa..000000000 --- a/benchmarks/workloads/redisbenchmark/__init__.py +++ /dev/null @@ -1,85 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Redis-benchmark tool.""" - -import re - -OPERATIONS = [ - "PING_INLINE", - "PING_BULK", - "SET", - "GET", - "INCR", - "LPUSH", - "RPUSH", - "LPOP", - "RPOP", - "SADD", - "HSET", - "SPOP", - "LRANGE_100", - "LRANGE_300", - "LRANGE_500", - "LRANGE_600", - "MSET", -] - -METRICS = dict() - -SAMPLE_DATA = """ -"PING_INLINE","48661.80" -"PING_BULK","50301.81" -"SET","48923.68" -"GET","49382.71" -"INCR","49975.02" -"LPUSH","49875.31" -"RPUSH","50276.52" -"LPOP","50327.12" -"RPOP","50556.12" -"SADD","49504.95" -"HSET","49504.95" -"SPOP","50025.02" -"LPUSH (needed to benchmark LRANGE)","48875.86" -"LRANGE_100 (first 100 elements)","33955.86" -"LRANGE_300 (first 300 elements)","16550.81" -"LRANGE_500 (first 450 elements)","13653.74" -"LRANGE_600 (first 600 elements)","11219.57" -"MSET (10 keys)","44682.75" -""" - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# Bind a metric for each operation noted above. -for op in OPERATIONS: - - def bind(metric): - """Bind op to a new scope.""" - - # pylint: disable=unused-argument - def parse(data: str, **kwargs) -> float: - """Operation throughput in requests/sec.""" - regex = r"\"" + metric + r"( .*)?\",\"(\d*.\d*)" - res = re.compile(regex).search(data) - if res: - return float(res.group(2)) - return 0.0 - - parse.__name__ = metric - return parse - - METRICS[op] = bind(op) diff --git a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py b/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py deleted file mode 100644 index 419ced059..000000000 --- a/benchmarks/workloads/redisbenchmark/redisbenchmark_test.py +++ /dev/null @@ -1,51 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser test.""" - -import sys - -import pytest - -from benchmarks.workloads import redisbenchmark - -RESULTS = { - "PING_INLINE": 48661.80, - "PING_BULK": 50301.81, - "SET": 48923.68, - "GET": 49382.71, - "INCR": 49975.02, - "LPUSH": 49875.31, - "RPUSH": 50276.52, - "LPOP": 50327.12, - "RPOP": 50556.12, - "SADD": 49504.95, - "HSET": 49504.95, - "SPOP": 50025.02, - "LRANGE_100": 33955.86, - "LRANGE_300": 16550.81, - "LRANGE_500": 13653.74, - "LRANGE_600": 11219.57, - "MSET": 44682.75 -} - - -def test_metrics(): - """Test all metrics.""" - for (metric, func) in redisbenchmark.METRICS.items(): - res = func(redisbenchmark.sample()) - assert float(res) == RESULTS[metric] - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD deleted file mode 100644 index a3be4fe92..000000000 --- a/benchmarks/workloads/ruby/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -filegroup( - name = "files", - srcs = [ - "Dockerfile", - "Gemfile", - "Gemfile.lock", - "config.ru", - "index.rb", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "Gemfile", - "Gemfile.lock", - "config.ru", - "index.rb", - ], -) diff --git a/benchmarks/workloads/ruby/Dockerfile b/benchmarks/workloads/ruby/Dockerfile deleted file mode 100644 index a9a7a7086..000000000 --- a/benchmarks/workloads/ruby/Dockerfile +++ /dev/null @@ -1,28 +0,0 @@ -# example based on https://github.com/errm/fib - -FROM ruby:2.5 - -RUN apt-get update -qq && apt-get install -y build-essential libpq-dev nodejs libsodium-dev - -# Set an environment variable where the Rails app is installed to inside of Docker image -ENV RAILS_ROOT /var/www/app_name -RUN mkdir -p $RAILS_ROOT - -# Set working directory -WORKDIR $RAILS_ROOT - -# Setting env up -ENV RAILS_ENV='production' -ENV RACK_ENV='production' - -# Adding gems -COPY Gemfile Gemfile -COPY Gemfile.lock Gemfile.lock -RUN bundle install --jobs 20 --retry 5 --without development test - -# Adding project files -COPY . . - -EXPOSE $PORT -STOPSIGNAL SIGINT -CMD ["bundle", "exec", "puma", "config.ru"] diff --git a/benchmarks/workloads/ruby/Gemfile b/benchmarks/workloads/ruby/Gemfile deleted file mode 100644 index 8f1bdad6e..000000000 --- a/benchmarks/workloads/ruby/Gemfile +++ /dev/null @@ -1,12 +0,0 @@ -source "https://rubygems.org" -# load a bunch of dependencies to take up memory -gem "sinatra" -gem "puma" -gem "redis" -gem 'rake' -gem 'squid', '~> 1.4' -gem 'cassandra-driver' -gem 'ruby-fann' -gem 'rbnacl' -gem 'bcrypt' -gem "activemerchant"
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock deleted file mode 100644 index ea9f0ea85..000000000 --- a/benchmarks/workloads/ruby/Gemfile.lock +++ /dev/null @@ -1,71 +0,0 @@ -GEM - remote: https://rubygems.org/ - specs: - activemerchant (1.105.0) - activesupport (>= 4.2) - builder (>= 2.1.2, < 4.0.0) - i18n (>= 0.6.9) - nokogiri (~> 1.4) - activesupport (5.2.3) - concurrent-ruby (~> 1.0, >= 1.0.2) - i18n (>= 0.7, < 2) - minitest (~> 5.1) - tzinfo (~> 1.1) - bcrypt (3.1.13) - builder (3.2.4) - cassandra-driver (3.2.3) - ione (~> 1.2) - concurrent-ruby (1.1.5) - ffi (1.12.2) - i18n (1.6.0) - concurrent-ruby (~> 1.0) - ione (1.2.4) - mini_portile2 (2.4.0) - minitest (5.11.3) - mustermann (1.0.3) - nokogiri (1.10.8) - mini_portile2 (~> 2.4.0) - pdf-core (0.7.0) - prawn (2.2.2) - pdf-core (~> 0.7.0) - ttfunk (~> 1.5) - puma (3.12.4) - rack (2.2.2) - rack-protection (2.0.5) - rack - rake (12.3.3) - rbnacl (7.1.1) - ffi - redis (4.1.1) - ruby-fann (1.2.6) - sinatra (2.0.5) - mustermann (~> 1.0) - rack (~> 2.0) - rack-protection (= 2.0.5) - tilt (~> 2.0) - squid (1.4.1) - activesupport (>= 4.0) - prawn (~> 2.2) - thread_safe (0.3.6) - tilt (2.0.9) - ttfunk (1.5.1) - tzinfo (1.2.5) - thread_safe (~> 0.1) - -PLATFORMS - ruby - -DEPENDENCIES - activemerchant - bcrypt - cassandra-driver - puma - rake - rbnacl - redis - ruby-fann - sinatra - squid (~> 1.4) - -BUNDLED WITH - 1.17.1 diff --git a/benchmarks/workloads/ruby/config.ru b/benchmarks/workloads/ruby/config.ru deleted file mode 100755 index fbd5acc82..000000000 --- a/benchmarks/workloads/ruby/config.ru +++ /dev/null @@ -1,2 +0,0 @@ -require './index' -run Sinatra::Application
\ No newline at end of file diff --git a/benchmarks/workloads/ruby/index.rb b/benchmarks/workloads/ruby/index.rb deleted file mode 100755 index 5fa85af93..000000000 --- a/benchmarks/workloads/ruby/index.rb +++ /dev/null @@ -1,14 +0,0 @@ -require "sinatra" -require "puma" -require "redis" -require "rake" -require "squid" -require "cassandra" -require "ruby-fann" -require "rbnacl" -require "bcrypt" -require "activemerchant" - -get "/" do - "Hello World!" -end
\ No newline at end of file diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD deleted file mode 100644 index 72ed9403d..000000000 --- a/benchmarks/workloads/ruby_template/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "Gemfile", - "Gemfile.lock", - "config.ru", - "index.erb", - "main.rb", - ], -) diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD deleted file mode 100644 index a70873065..000000000 --- a/benchmarks/workloads/sleep/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/sleep/Dockerfile b/benchmarks/workloads/sleep/Dockerfile deleted file mode 100644 index 24c72e07a..000000000 --- a/benchmarks/workloads/sleep/Dockerfile +++ /dev/null @@ -1,3 +0,0 @@ -FROM alpine:latest - -CMD ["sleep", "315360000"] diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD deleted file mode 100644 index ab2556064..000000000 --- a/benchmarks/workloads/sysbench/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "sysbench", - srcs = ["__init__.py"], -) - -py_test( - name = "sysbench_test", - srcs = ["sysbench_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":sysbench", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/sysbench/Dockerfile b/benchmarks/workloads/sysbench/Dockerfile deleted file mode 100644 index 8225e0e14..000000000 --- a/benchmarks/workloads/sysbench/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -FROM ubuntu:18.04 - -RUN set -x \ - && apt-get update \ - && apt-get install -y \ - sysbench \ - && rm -rf /var/lib/apt/lists/* - -# Parameterize the tests. -ENV test cpu -ENV threads 1 -ENV options "" - -# run sysbench once as a warm-up and take the second result -CMD ["sh", "-c", "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null && \ -sysbench --threads=${threads} ${options} ${test} run"] diff --git a/benchmarks/workloads/sysbench/__init__.py b/benchmarks/workloads/sysbench/__init__.py deleted file mode 100644 index de357b4db..000000000 --- a/benchmarks/workloads/sysbench/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Sysbench.""" - -import re - -STD_REGEX = r"events per second:\s*(\d*.?\d*)\n" -MEM_REGEX = r"Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)" -ALT_REGEX = r"execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)" -AVG_REGEX = r"avg:[^\n^\d]*(\d*\.?\d*)" - -SAMPLE_CPU_DATA = """ -sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) - -Running the test with following options: -Number of threads: 8 -Initializing random number generator from current time - - -Prime numbers limit: 10000 - -Initializing worker threads... - -Threads started! - -CPU speed: - events per second: 9093.38 - -General statistics: - total time: 10.0007s - total number of events: 90949 - -Latency (ms): - min: 0.64 - avg: 0.88 - max: 24.65 - 95th percentile: 1.55 - sum: 79936.91 - -Threads fairness: - events (avg/stddev): 11368.6250/831.38 - execution time (avg/stddev): 9.9921/0.01 -""" - -SAMPLE_MEMORY_DATA = """ -sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) - -Running the test with following options: -Number of threads: 8 -Initializing random number generator from current time - - -Running memory speed test with the following options: - block size: 1KiB - total size: 102400MiB - operation: write - scope: global - -Initializing worker threads... - -Threads started! - -Total operations: 47999046 (9597428.64 per second) - -46874.07 MiB transferred (9372.49 MiB/sec) - - -General statistics: - total time: 5.0001s - total number of events: 47999046 - -Latency (ms): - min: 0.00 - avg: 0.00 - max: 0.21 - 95th percentile: 0.00 - sum: 33165.91 - -Threads fairness: - events (avg/stddev): 5999880.7500/111242.52 - execution time (avg/stddev): 4.1457/0.09 -""" - -SAMPLE_MUTEX_DATA = """ -sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) - -Running the test with following options: -Number of threads: 8 -Initializing random number generator from current time - - -Initializing worker threads... - -Threads started! - - -General statistics: - total time: 3.7869s - total number of events: 8 - -Latency (ms): - min: 3688.56 - avg: 3754.03 - max: 3780.94 - 95th percentile: 3773.42 - sum: 30032.28 - -Threads fairness: - events (avg/stddev): 1.0000/0.00 - execution time (avg/stddev): 3.7540/0.03 -""" - - -# pylint: disable=unused-argument -def sample(test, **kwargs): - switch = { - "cpu": SAMPLE_CPU_DATA, - "memory": SAMPLE_MEMORY_DATA, - "mutex": SAMPLE_MUTEX_DATA, - "randwr": SAMPLE_CPU_DATA - } - return switch[test] - - -# pylint: disable=unused-argument -def cpu_events_per_second(data: str, **kwargs) -> float: - """Returns events per second.""" - return float(re.compile(STD_REGEX).search(data).group(1)) - - -# pylint: disable=unused-argument -def memory_ops_per_second(data: str, **kwargs) -> float: - """Returns memory operations per second.""" - return float(re.compile(MEM_REGEX).search(data).group(1)) - - -# pylint: disable=unused-argument -def mutex_time(data: str, count: int, locks: int, threads: int, - **kwargs) -> float: - """Returns normalized mutex time (lower is better).""" - value = float(re.compile(ALT_REGEX).search(data).group(1)) - contention = float(threads) / float(locks) - scale = contention * float(count) / 100000000.0 - return value / scale - - -# pylint: disable=unused-argument -def mutex_deviation(data: str, **kwargs) -> float: - """Returns deviation for threads.""" - return float(re.compile(ALT_REGEX).search(data).group(2)) - - -# pylint: disable=unused-argument -def mutex_latency(data: str, **kwargs) -> float: - """Returns average mutex latency.""" - return float(re.compile(AVG_REGEX).search(data).group(1)) diff --git a/benchmarks/workloads/sysbench/sysbench_test.py b/benchmarks/workloads/sysbench/sysbench_test.py deleted file mode 100644 index 3fb541fd2..000000000 --- a/benchmarks/workloads/sysbench/sysbench_test.py +++ /dev/null @@ -1,34 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Parser test.""" - -import sys - -import pytest - -from benchmarks.workloads import sysbench - - -def test_sysbench_parser(): - """Test the basic parser.""" - assert sysbench.cpu_events_per_second(sysbench.sample("cpu")) == 9093.38 - assert sysbench.memory_ops_per_second(sysbench.sample("memory")) == 9597428.64 - assert sysbench.mutex_time(sysbench.sample("mutex"), 1, 1, - 100000000.0) == 3.754 - assert sysbench.mutex_deviation(sysbench.sample("mutex")) == 0.03 - assert sysbench.mutex_latency(sysbench.sample("mutex")) == 3754.03 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD deleted file mode 100644 index f8c43bca1..000000000 --- a/benchmarks/workloads/syscall/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") -load("//benchmarks:defs.bzl", "test_deps") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "syscall", - srcs = ["__init__.py"], -) - -py_test( - name = "syscall_test", - srcs = ["syscall_test.py"], - python_version = "PY3", - deps = test_deps + [ - ":syscall", - ], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - "syscall.c", - ], -) diff --git a/benchmarks/workloads/syscall/Dockerfile b/benchmarks/workloads/syscall/Dockerfile deleted file mode 100644 index a2088d953..000000000 --- a/benchmarks/workloads/syscall/Dockerfile +++ /dev/null @@ -1,6 +0,0 @@ -FROM gcc:latest -COPY . /usr/src/syscall -WORKDIR /usr/src/syscall -RUN gcc -O2 -o syscall syscall.c -ENV count 1000000 -CMD ["sh", "-c", "./syscall ${count}"] diff --git a/benchmarks/workloads/syscall/__init__.py b/benchmarks/workloads/syscall/__init__.py deleted file mode 100644 index dc9028faa..000000000 --- a/benchmarks/workloads/syscall/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Simple syscall test.""" - -import re - -SAMPLE_DATA = "Called getpid syscall 1000000 times: 1117 ms, 500 ns each." - - -# pylint: disable=unused-argument -def sample(**kwargs) -> str: - return SAMPLE_DATA - - -# pylint: disable=unused-argument -def syscall_time_ns(data: str, **kwargs) -> int: - """Returns average system call time.""" - return float(re.compile(r"(\d+)\sns each.").search(data).group(1)) diff --git a/benchmarks/workloads/syscall/syscall.c b/benchmarks/workloads/syscall/syscall.c deleted file mode 100644 index ded030397..000000000 --- a/benchmarks/workloads/syscall/syscall.c +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#define _GNU_SOURCE -#include <stdio.h> -#include <stdlib.h> -#include <sys/syscall.h> -#include <sys/types.h> -#include <time.h> -#include <unistd.h> - -// Short program that calls getpid() a number of times and outputs time -// diference from the MONOTONIC clock. -int main(int argc, char** argv) { - struct timespec start, stop; - long result; - char buf[80]; - - if (argc < 2) { - printf("Usage:./syscall NUM_TIMES_TO_CALL"); - return 1; - } - - if (clock_gettime(CLOCK_MONOTONIC, &start)) return 1; - - long loops = atoi(argv[1]); - for (long i = 0; i < loops; i++) { - syscall(SYS_gettimeofday, 0, 0); - } - - if (clock_gettime(CLOCK_MONOTONIC, &stop)) return 1; - - if ((stop.tv_nsec - start.tv_nsec) < 0) { - result = (stop.tv_sec - start.tv_sec - 1) * 1000; - result += (stop.tv_nsec - start.tv_nsec + 1000000000) / (1000 * 1000); - } else { - result = (stop.tv_sec - start.tv_sec) * 1000; - result += (stop.tv_nsec - start.tv_nsec) / (1000 * 1000); - } - - printf("Called getpid syscall %d times: %lu ms, %lu ns each.\n", loops, - result, result * 1000000 / loops); - - return 0; -} diff --git a/benchmarks/workloads/syscall/syscall_test.py b/benchmarks/workloads/syscall/syscall_test.py deleted file mode 100644 index 72f027de1..000000000 --- a/benchmarks/workloads/syscall/syscall_test.py +++ /dev/null @@ -1,27 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import pytest - -from benchmarks.workloads import syscall - - -def test_syscall_time_ns(): - assert syscall.syscall_time_ns(syscall.sample()) == 500 - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD deleted file mode 100644 index a7b7742f4..000000000 --- a/benchmarks/workloads/tensorflow/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -py_library( - name = "tensorflow", - srcs = ["__init__.py"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], -) diff --git a/benchmarks/workloads/tensorflow/__init__.py b/benchmarks/workloads/tensorflow/__init__.py deleted file mode 100644 index b5ec213f8..000000000 --- a/benchmarks/workloads/tensorflow/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A Tensorflow example.""" - - -# pylint: disable=unused-argument -def run_time(value, **kwargs): - """Returns the startup and runtime of the Tensorflow workload in seconds.""" - return value diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD deleted file mode 100644 index eba23d325..000000000 --- a/benchmarks/workloads/true/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "pkg_tar") - -package( - default_visibility = ["//benchmarks:__subpackages__"], - licenses = ["notice"], -) - -pkg_tar( - name = "tar", - srcs = [ - "Dockerfile", - ], - extension = "tar", -) diff --git a/debian/BUILD b/debian/BUILD new file mode 100644 index 000000000..331f44a5c --- /dev/null +++ b/debian/BUILD @@ -0,0 +1,59 @@ +load("//tools:defs.bzl", "pkg_deb", "pkg_tar") + +package(licenses = ["notice"]) + +pkg_tar( + name = "debian-bin", + srcs = [ + "//runsc", + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", + ], + mode = "0755", + package_dir = "/usr/bin", +) + +pkg_tar( + name = "debian-data", + extension = "tar.gz", + deps = [ + ":debian-bin", + "//shim:config", + ], +) + +genrule( + name = "debian-version", + # Note that runsc must appear in the srcs parameter and not the tools + # parameter, otherwise it will not be stamped. This is reasonable, as tools + # may be encoded differently in the build graph (cached more aggressively + # because they are assumes to be hermetic). + srcs = ["//runsc"], + outs = ["version.txt"], + # Note that the little dance here is necessary because files in the $(SRCS) + # attribute are not executable by default, and we can't touch in place. + cmd = "cp $(location //runsc:runsc) $(@D)/runsc && \ + chmod a+x $(@D)/runsc && \ + $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ + rm -f $(@D)/runsc", + stamp = 1, +) + +pkg_deb( + name = "debian", + architecture = "amd64", + data = ":debian-data", + # Note that the description_file will be flatten (all newlines removed), + # and therefore it is kept to a simple one-line description. The expected + # format for debian packages is "short summary\nLonger explanation of + # tool." and this is impossible with the flattening. + description_file = "description", + homepage = "https://gvisor.dev/", + maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", + package = "runsc", + postinst = "postinst.sh", + version_file = ":version.txt", + visibility = [ + "//visibility:public", + ], +) diff --git a/runsc/debian/description b/debian/description index 9e8e08805..9e8e08805 100644 --- a/runsc/debian/description +++ b/debian/description diff --git a/runsc/debian/postinst.sh b/debian/postinst.sh index dc7aeee87..6a326f823 100755 --- a/runsc/debian/postinst.sh +++ b/debian/postinst.sh @@ -18,7 +18,14 @@ if [ "$1" != configure ]; then exit 0 fi +# Update docker configuration. if [ -f /etc/docker/daemon.json ]; then runsc install - systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 + if systemctl is-active -q docker; then + systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 + fi fi + +# For containerd-based installers, we don't automatically update the +# configuration. If it uses a v2 shim, then it will find the package binaries +# automatically when provided the appropriate annotation. diff --git a/g3doc/BUILD b/g3doc/BUILD index c315d38be..f91a77b6f 100644 --- a/g3doc/BUILD +++ b/g3doc/BUILD @@ -31,7 +31,7 @@ doc( category = "Project", permalink = "/community/", subcategory = "Community", - weight = "95", + weight = "10", ) doc( @@ -40,5 +40,5 @@ doc( category = "Project", permalink = "/community/style_guide/", subcategory = "Community", - weight = "10", + weight = "99", ) diff --git a/g3doc/README.md b/g3doc/README.md index 304a91493..22bfb15f7 100644 --- a/g3doc/README.md +++ b/g3doc/README.md @@ -117,9 +117,7 @@ for more information on filesystem bundles. `runsc` implements multiple commands that perform various functions such as starting, stopping, listing, and querying the status of containers. -### Sentry - -<a name="sentry"></a> <!-- For deep linking. --> +### Sentry {#sentry} The Sentry is the largest component of gVisor. It can be thought of as a application kernel. The Sentry implements all the kernel functionality needed by @@ -136,9 +134,7 @@ calls it makes. For example, the Sentry is not able to open files directly; file system operations that extend beyond the sandbox (not internal `/proc` files, pipes, etc) are sent to the Gofer, described below. -### Gofer - -<a name="gofer"></a> <!-- For deep linking. --> +### Gofer {#gofer} The Gofer is a standard host process which is started with each container and communicates with the Sentry via the [9P protocol][9p] over a socket or shared @@ -146,13 +142,13 @@ memory channel. The Sentry process is started in a restricted seccomp container without access to file system resources. The Gofer mediates all access to the these resources, providing an additional level of isolation. -### Application +### Application {#application} The application is a normal Linux binary provided to gVisor in an OCI runtime bundle. gVisor aims to provide an environment equivalent to Linux v4.4, so applications should be able to run unmodified. However, gVisor does not presently implement every system call, `/proc` file, or `/sys` file so some -incompatibilities may occur. See [Commpatibility](./user_guide/compatibility.md) +incompatibilities may occur. See [Compatibility](./user_guide/compatibility.md) for more information. [9p]: https://en.wikipedia.org/wiki/9P_(protocol) diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md index 39dbb0045..b981f0c01 100644 --- a/g3doc/architecture_guide/performance.md +++ b/g3doc/architecture_guide/performance.md @@ -30,7 +30,7 @@ is distinct from **structural costs**. Improvements here are ongoing and driven by the workloads that matter to gVisor users and contributors. This page provides a guide for understanding baseline performance, and calls out -distint **structural costs** and **implementation costs**, highlighting where +distinct **structural costs** and **implementation costs**, highlighting where improvements are possible and not possible. While we include a variety of workloads here, it’s worth emphasizing that gVisor @@ -211,7 +211,7 @@ url="/performance/applications.csv" title="perf.py http.(node|ruby) The above figure shows the result of simple `node` and `ruby` web services that render a template upon receiving a request. Because these synthetic benchmarks -do minimal work per request, must like the `redis` case, they suffer from high +do minimal work per request, much like the `redis` case, they suffer from high overheads. In practice, the more work an application does the smaller the impact of **structural costs** become. diff --git a/g3doc/architecture_guide/resources.md b/g3doc/architecture_guide/resources.md index 1dec37bd1..fc997d40c 100644 --- a/g3doc/architecture_guide/resources.md +++ b/g3doc/architecture_guide/resources.md @@ -19,12 +19,12 @@ sandboxed process: Much like a Virtual Machine (VM), a gVisor sandbox appears as an opaque process on the system. Processes within the sandbox do not manifest as processes on the -host system, and process-level interactions within the sandbox requires entering +host system, and process-level interactions within the sandbox require entering the sandbox (e.g. via a [Docker exec][exec]). ## Networking -The sandbox attaches a network endpoint to the system, but runs it's own network +The sandbox attaches a network endpoint to the system, but runs its own network stack. All network resources, other than packets in flight on the host, exist only inside the sandbox, bound by relevant resource limits. diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md index b99b86332..9363d834c 100644 --- a/g3doc/architecture_guide/security.md +++ b/g3doc/architecture_guide/security.md @@ -104,7 +104,7 @@ interactions with a guest operating system and a set of virtualized hardware devices. These hardware devices are then implemented via the host System API by a Virtual Machine Monitor (VMM). The Sentry similarly prevents direct interactions by providing its own implementation of the System API that the -application must interact with. Applications are not able to to directly craft +application must interact with. Applications are not able to directly craft specific arguments or flags for the host System API, or interact directly with host primitives. diff --git a/g3doc/style.md b/g3doc/style.md index d10549fe9..8258b0233 100644 --- a/g3doc/style.md +++ b/g3doc/style.md @@ -46,6 +46,15 @@ protected. Each field or variable protected by a mutex should state as such in a comment on the field or variable declaration. +### Function comments + +Functions with special entry conditions (e.g., a lock must be held) should state +these conditions in a `Preconditions:` comment block. One condition per line; +multiple conditions are specified with a bullet (`*`). + +Functions with notable exit conditions (e.g., a `Done` function must eventually +be called by the caller) can similarly have a `Postconditions:` block. + ### Unused returns Unused returns should be explicitly ignored with underscores. If there is a diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md index 89df65e99..514fe3918 100644 --- a/g3doc/user_guide/FAQ.md +++ b/g3doc/user_guide/FAQ.md @@ -74,11 +74,10 @@ directories. ### I'm getting an error like: `panic: unable to attach: operation not permitted` or `fork/exec /proc/self/exe: invalid argument: unknown` {#runsc-perms} -Make sure that permissions and the owner is correct on the `runsc` binary. +Make sure that permissions is correct on the `runsc` binary. ```bash -sudo chown root:root /usr/local/bin/runsc -sudo chmod 0755 /usr/local/bin/runsc +sudo chmod a+rx /usr/local/bin/runsc ``` ### I'm getting an error like `mount submount "/etc/hostname": creating mount with source ".../hostname": input/output error: unknown.` {#memlock} @@ -96,6 +95,30 @@ containerd. See [issue #1765](https://gvisor.dev/issue/1765) for more details. +### I'm getting an error like `RuntimeHandler "runsc" not supported` {#runtime-handler} + +This error indicates that the Kubernetes CRI runtime was not set up to handle +`runsc` as a runtime handler. Please ensure that containerd configuration has +been created properly and containerd has been restarted. See the +[containerd quick start](containerd/quick_start.md) for more details. + +If you have ensured that containerd has been set up properly and you used +kubeadm to create your cluster please check if Docker is also installed on that +system. Kubeadm prefers using Docker if both Docker and containerd are +installed. + +Please recreate your cluster and set the `--cni-socket` option on kubeadm +commands. For example: + +```bash +kubeadm init --cni-socket=/var/run/containerd/containerd.sock` ... +``` + +To fix an existing cluster edit the `/var/lib/kubelet/kubeadm-flags.env` file +and set the `--container-runtime` flag to `remote` and set the +`--container-runtime-endpoint` flag to point to the containerd socket. e.g. +`/var/run/containerd/containerd.sock`. + ### My container cannot resolve another container's name when using Docker user defined bridge {#docker-bridge} This is normally indicated by errors like `bad address 'container-name'` when diff --git a/g3doc/user_guide/containerd/BUILD b/g3doc/user_guide/containerd/BUILD new file mode 100644 index 000000000..979d46105 --- /dev/null +++ b/g3doc/user_guide/containerd/BUILD @@ -0,0 +1,33 @@ +load("//website:defs.bzl", "doc") + +package( + default_visibility = ["//website:__pkg__"], + licenses = ["notice"], +) + +doc( + name = "quick_start", + src = "quick_start.md", + category = "User Guide", + permalink = "/docs/user_guide/containerd/quick_start/", + subcategory = "Containerd", + weight = "10", +) + +doc( + name = "configuration", + src = "configuration.md", + category = "User Guide", + permalink = "/docs/user_guide/containerd/configuration/", + subcategory = "Containerd", + weight = "90", +) + +doc( + name = "containerd_11", + src = "containerd_11.md", + category = "User Guide", + permalink = "/docs/user_guide/containerd/containerd_11/", + subcategory = "Containerd", + weight = "99", +) diff --git a/g3doc/user_guide/containerd/configuration.md b/g3doc/user_guide/containerd/configuration.md new file mode 100644 index 000000000..5d485c24b --- /dev/null +++ b/g3doc/user_guide/containerd/configuration.md @@ -0,0 +1,70 @@ +# Containerd Advanced Configuration + +This document describes how to configure runtime options for +`containerd-shim-runsc-v1`. This follows the +[Containerd Quick Start](./quick_start.md) and requires containerd 1.2 or later. + +### Update `/etc/containerd/config.toml` to point to a configuration file for `containerd-shim-runsc-v1`. + +`containerd-shim-runsc-v1` supports a few different configuration options based +on the version of containerd that is used. For versions >= 1.3, it supports a +configurable `ConfigPath` in the containerd runtime configuration. + +```shell +cat <<EOF | sudo tee /etc/containerd/config.toml +disabled_plugins = ["restart"] +[plugins.linux] + shim_debug = true +[plugins.cri.containerd.runtimes.runsc] + runtime_type = "io.containerd.runsc.v1" +[plugins.cri.containerd.runtimes.runsc.options] + TypeUrl = "io.containerd.runsc.v1.options" + # containerd 1.3 only! + ConfigPath = "/etc/containerd/runsc.toml" +EOF +``` + +When you are done restart containerd to pick up the new configuration files. + +```shell +sudo systemctl restart containerd +``` + +### Configure `/etc/containerd/runsc.toml` + +> Note: For containerd 1.2, the config file should named `config.toml` and +> located in the runtime root. By default, this is `/run/containerd/runsc`. + +The set of options that can be configured can be found in +[options.go](https://github.com/google/gvisor/blob/master/pkg/shim/v2/options/options.go). + +#### Example: Enable the KVM platform + +gVisor enables the use of a number of platforms. This example shows how to +configure `containerd-shim-runsc-v1` to use gvisor with the KVM platform. + +Find out more about platform in the +[Platforms Guide](../../architecture_guide/platforms.md). + +```shell +cat <<EOF | sudo tee /etc/containerd/runsc.toml +[runsc_config] +platform = "kvm" +EOF +``` + +### Example: Enable gVisor debug logging + +gVisor debug logging can be enabled by setting the `debug` and `debug-log` flag. +The shim will replace "%ID%" with the container ID, and "%COMMAND%" with the +runsc command (run, boot, etc.) in the path of the `debug-log` flag. + +Find out more about debugging in the [debugging guide](../debugging.md). + +```shell +cat <<EOF | sudo tee /etc/containerd/runsc.toml +[runsc_config] + debug=true + debug-log=/var/log/%ID%/gvisor.%COMMAND%.log +EOF +``` diff --git a/g3doc/user_guide/containerd/containerd_11.md b/g3doc/user_guide/containerd/containerd_11.md new file mode 100644 index 000000000..50befbdf4 --- /dev/null +++ b/g3doc/user_guide/containerd/containerd_11.md @@ -0,0 +1,163 @@ +# Older Versions (containerd 1.1) + +This document describes how to install and run the `gvisor-containerd-shim` +using the untrusted workload CRI extension. This requires `containerd` 1.1 or +later. + +*Note: The untrusted workload CRI extension is deprecated by containerd and +`gvisor-containerd-shim` is maintained on a best-effort basis. If you are using +containerd 1.2+, please see the +[containerd 1.2+ documentation](./quick_start.md) and use +`containerd-shim-runsc-v1`.* + +## Requirements + +- **runsc** and **gvisor-containerd-shim**: See the + [installation guide](/docs/user_guide/install/). +- **containerd**: See the [containerd website](https://containerd.io/) for + information on how to install containerd. + +## Configure containerd + +Create the configuration for the gvisor shim in +`/etc/containerd/gvisor-containerd-shim.toml`: + +```shell +cat <<EOF | sudo tee /etc/containerd/gvisor-containerd-shim.toml +# This is the path to the default runc containerd-shim. +runc_shim = "/usr/local/bin/containerd-shim" +EOF +``` + +Update `/etc/containerd/config.toml`. Be sure to update the path to +`gvisor-containerd-shim` and `runsc` if necessary: + +```shell +cat <<EOF | sudo tee /etc/containerd/config.toml +disabled_plugins = ["restart"] +[plugins.linux] + shim = "/usr/local/bin/gvisor-containerd-shim" + shim_debug = true +[plugins.cri.containerd.untrusted_workload_runtime] + runtime_type = "io.containerd.runtime.v1.linux" + runtime_engine = "/usr/local/bin/runsc" + runtime_root = "/run/containerd/runsc" +EOF +``` + +Restart `containerd`: + +```shell +sudo systemctl restart containerd +``` + +## Usage + +You can run containers in gVisor via containerd's CRI. + +### Install crictl + +Download and install the `crictl` binary: + +```shell +{ +wget https://github.com/kubernetes-sigs/cri-tools/releases/download/v1.13.0/crictl-v1.13.0-linux-amd64.tar.gz +tar xf crictl-v1.13.0-linux-amd64.tar.gz +sudo mv crictl /usr/local/bin +} +``` + +Write the `crictl` configuration file: + +```shell +cat <<EOF | sudo tee /etc/crictl.yaml +runtime-endpoint: unix:///run/containerd/containerd.sock +EOF +``` + +### Create the nginx Sandbox in gVisor + +Pull the nginx image: + +```shell +sudo crictl pull nginx +``` + +Create the sandbox creation request: + +```shell +cat <<EOF | tee sandbox.json +{ + "metadata": { + "name": "nginx-sandbox", + "namespace": "default", + "attempt": 1, + "uid": "hdishd83djaidwnduwk28bcsb" + }, + "annotations": { + "io.kubernetes.cri.untrusted-workload": "true" + }, + "linux": { + }, + "log_directory": "/tmp" +} +EOF +``` + +Create the pod in gVisor: + +```shell +SANDBOX_ID=$(sudo crictl runp sandbox.json) +``` + +### Run the nginx Container in the Sandbox + +Create the nginx container creation request: + +```shell +cat <<EOF | tee container.json +{ + "metadata": { + "name": "nginx" + }, + "image":{ + "image": "nginx" + }, + "log_path":"nginx.0.log", + "linux": { + } +} +EOF +``` + +Create the nginx container: + +```shell +CONTAINER_ID=$(sudo crictl create ${SANDBOX_ID} container.json sandbox.json) +``` + +Start the nginx container: + +```shell +sudo crictl start ${CONTAINER_ID} +``` + +### Validate the container + +Inspect the created pod: + +```shell +sudo crictl inspectp ${SANDBOX_ID} +``` + +Inspect the nginx container: + +```shell +sudo crictl inspect ${CONTAINER_ID} +``` + +Verify that nginx is running in gVisor: + +```shell +sudo crictl exec ${CONTAINER_ID} dmesg | grep -i gvisor +``` diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md new file mode 100644 index 000000000..2f67eecb3 --- /dev/null +++ b/g3doc/user_guide/containerd/quick_start.md @@ -0,0 +1,176 @@ +# Containerd Quick Start + +This document describes how to install and configure `containerd-shim-runsc-v1` +using the containerd runtime handler support on `containerd` 1.2 or later. + +## Requirements + +- **runsc** and **containerd-shim-runsc-v1**: See the + [installation guide](/docs/user_guide/install/). +- **containerd**: See the [containerd website](https://containerd.io/) for + information on how to install containerd. + +## Configure containerd + +Update `/etc/containerd/config.toml`. Make sure `containerd-shim-runsc-v1` is in +`${PATH}` or in the same directory as `containerd` binary. + +```shell +cat <<EOF | sudo tee /etc/containerd/config.toml +disabled_plugins = ["restart"] +[plugins.linux] + shim_debug = true +[plugins.cri.containerd.runtimes.runsc] + runtime_type = "io.containerd.runsc.v1" +EOF +``` + +Restart `containerd`: + +```shell +sudo systemctl restart containerd +``` + +## Usage + +You can run containers in gVisor via containerd's CRI. + +### Install crictl + +Download and install the `crictl`` binary: + +```shell +{ +wget https://github.com/kubernetes-sigs/cri-tools/releases/download/v1.13.0/crictl-v1.13.0-linux-amd64.tar.gz +tar xf crictl-v1.13.0-linux-amd64.tar.gz +sudo mv crictl /usr/local/bin +} +``` + +Write the `crictl` configuration file: + +```shell +cat <<EOF | sudo tee /etc/crictl.yaml +runtime-endpoint: unix:///run/containerd/containerd.sock +EOF +``` + +### Create the nginx sandbox in gVisor + +Pull the nginx image: + +```shell +sudo crictl pull nginx +``` + +Create the sandbox creation request: + +```shell +cat <<EOF | tee sandbox.json +{ + "metadata": { + "name": "nginx-sandbox", + "namespace": "default", + "attempt": 1, + "uid": "hdishd83djaidwnduwk28bcsb" + }, + "linux": { + }, + "log_directory": "/tmp" +} +EOF +``` + +Create the pod in gVisor: + +```shell +SANDBOX_ID=$(sudo crictl runp --runtime runsc sandbox.json) +``` + +### Run the nginx container in the sandbox + +Create the nginx container creation request: + +```shell +cat <<EOF | tee container.json +{ + "metadata": { + "name": "nginx" + }, + "image":{ + "image": "nginx" + }, + "log_path":"nginx.0.log", + "linux": { + } +} +EOF +``` + +Create the nginx container: + +```shell +CONTAINER_ID=$(sudo crictl create ${SANDBOX_ID} container.json sandbox.json) +``` + +Start the nginx container: + +```shell +sudo crictl start ${CONTAINER_ID} +``` + +### Validate the container + +Inspect the created pod: + +```shell +sudo crictl inspectp ${SANDBOX_ID} +``` + +Inspect the nginx container: + +```shell +sudo crictl inspect ${CONTAINER_ID} +``` + +Verify that nginx is running in gVisor: + +```shell +sudo crictl exec ${CONTAINER_ID} dmesg | grep -i gvisor +``` + +### Set up the Kubernetes RuntimeClass + +Install the RuntimeClass for gVisor: + +```shell +cat <<EOF | kubectl apply -f - +apiVersion: node.k8s.io/v1beta1 +kind: RuntimeClass +metadata: + name: gvisor +handler: runsc +EOF +``` + +Create a Pod with the gVisor RuntimeClass: + +```shell +cat <<EOF | kubectl apply -f - +apiVersion: v1 +kind: Pod +metadata: + name: nginx-gvisor +spec: + runtimeClassName: gvisor + containers: + - name: nginx + image: nginx +EOF +``` + +Verify that the Pod is running: + +```shell +kubectl get pod nginx-gvisor -o wide +``` diff --git a/g3doc/user_guide/debugging.md b/g3doc/user_guide/debugging.md index 0525fd5c0..54fdce34f 100644 --- a/g3doc/user_guide/debugging.md +++ b/g3doc/user_guide/debugging.md @@ -129,3 +129,13 @@ go tool pprof -top /usr/local/bin/runsc /tmp/cpu.prof ``` [pprof]: https://github.com/google/pprof/blob/master/doc/README.md + +### Docker Proxy + +When forwarding a port to the container, Docker will likely route traffic +through the [docker-proxy][]. This proxy may make profiling noisy, so it can be +helpful to bypass it. Do so by sending traffic directly to the container IP and +port. e.g., if the `docker0` IP is `192.168.9.1`, the container IP is likely a +subsequent IP, such as `192.168.9.2`. + +[docker-proxy]: https://windsock.io/the-docker-proxy/ diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index 9afdd264d..abb9e8582 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -5,6 +5,68 @@ > Note: gVisor supports only x86\_64 and requires Linux 4.14.77+ > ([older Linux](./networking.md#gso)). +## Install latest release {#install-latest} + +To download and install the latest release manually follow these steps: + +```bash +( + set -e + URL=https://storage.googleapis.com/gvisor/releases/release/latest + wget ${URL}/runsc ${URL}/runsc.sha512 + sha512sum -c runsc.sha512 + rm -f runsc.sha512 + sudo mv runsc /usr/local/bin + sudo chmod a+rx /usr/local/bin/runsc +) +``` + +To install gVisor with Docker, run the following commands: + +```bash +/usr/local/bin/runsc install +sudo systemctl restart docker +docker run --rm --runtime=runsc hello-world +``` + +For more details about using gVisor with Docker, see +[Docker Quick Start](./quick_start/docker.md) + +Note: It is important to copy `runsc` to a location that is readable and +executable to all users, since `runsc` executes itself as user `nobody` to avoid +unnecessary privileges. The `/usr/local/bin` directory is a good place to put +the `runsc` binary. + +## Install from an `apt` repository + +First, appropriate dependencies must be installed to allow `apt` to install +packages via https: + +```bash +sudo apt-get update && \ +sudo apt-get install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common +``` + +Next, the configure the key used to sign archives and the repository: + +```bash +curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" +``` + +Now the runsc package can be installed: + +```bash +sudo apt-get update && sudo apt-get install -y runsc +``` + +If you have Docker installed, it will be automatically configured. + ## Versions The `runsc` binaries and repositories are available in multiple versions and @@ -21,12 +83,16 @@ Binaries are available for every commit on the `master` branch, and are available at the following URL: `https://storage.googleapis.com/gvisor/releases/master/latest/runsc` +`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512` -Checksums for the release binary are at: +You can use this link with the steps described in +[Install latest release](#install-latest). -`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512` +For `apt` installation, use the `master` to configure the repository: -For `apt` installation, use the `master` as the `${DIST}` below. +```bash +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases master main" +``` ### Nightly @@ -34,18 +100,22 @@ Nightly releases are built most nights from the master branch, and are available at the following URL: `https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc` - -Checksums for the release binary are at: - `https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc.sha512` +You can use this link with the steps described in +[Install latest release](#install-latest). + Specific nightly releases can be found at: `https://storage.googleapis.com/gvisor/releases/nightly/${yyyy-mm-dd}/runsc` Note that a release may not be available for every day. -For `apt` installation, use the `nightly` as the `${DIST}` below. +For `apt` installation, use the `nightly` to configure the repository: + +```bash +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases nightly main" +``` ### Latest release @@ -53,105 +123,47 @@ The latest official release is available at the following URL: `https://storage.googleapis.com/gvisor/releases/release/latest` -For `apt` installation, use the `release` as the `${DIST}` below. - -### Specific release - -A given release release is available at the following URL: - -`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}` - -See the [releases][releases] page for information about specific releases. - -For `apt` installation of a specific release, which may include point updates, -use the date of the release, e.g. `${yyyymmdd}`, as the `${DIST}` below. - -> Note: only newer releases may be available as `apt` repositories. - -### Point release - -A given point release is available at the following URL: - -`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}` - -Note that `apt` installation of a specific point release is not supported. - -## Install from an `apt` repository +You can use this link with the steps described in +[Install latest release](#install-latest). -First, appropriate dependencies must be installed to allow `apt` to install -packages via https: +For `apt` installation, use the `release` to configure the repository: ```bash -sudo apt-get update && \ -sudo apt-get install -y \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" ``` -Next, the key used to sign archives should be added to your `apt` keychain: - -```bash -curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - -``` +### Specific release -Based on the release type, you will need to substitute `${DIST}` below, using -one of: +A given release release is available at the following URL: -* `master`: For HEAD. -* `nightly`: For nightly releases. -* `release`: For the latest release. -* `${yyyymmdd}`: For a specific releases (see above). +`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}` -The repository for the release you wish to install should be added: +You can use this link with the steps described in +[Install latest release](#install-latest). -```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases ${DIST} main" -``` +See the [releases](https://github.com/google/gvisor/releases) page for +information about specific releases. -For example, to install the latest official release, you can use: +For `apt` installation of a specific release, which may include point updates, +use the date of the release for repository, e.g. `${yyyymmdd}`. ```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases yyyymmdd main" ``` -Now the runsc package can be installed: - -```bash -sudo apt-get update && sudo apt-get install -y runsc -``` +> Note: only newer releases may be available as `apt` repositories. -If you have Docker installed, it will be automatically configured. +### Point release -## Install directly +A given point release is available at the following URL: -The binary URLs provided above can be used to install directly. For example, the -latest nightly binary can be downloaded, validated, and placed in an appropriate -location by running: +`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}` -```bash -( - set -e - URL=https://storage.googleapis.com/gvisor/releases/nightly/latest - wget ${URL}/runsc - wget ${URL}/runsc.sha512 - sha512sum -c runsc.sha512 - rm -f runsc.sha512 - sudo mv runsc /usr/local/bin - sudo chown root:root /usr/local/bin/runsc - sudo chmod 0755 /usr/local/bin/runsc -) -``` +You can use this link with the steps described in +[Install latest release](#install-latest). -**It is important to copy this binary to a location that is accessible to all -users, and ensure it is executable by all users**, since `runsc` executes itself -as user `nobody` to avoid unnecessary privileges. The `/usr/local/bin` directory -is a good place to put the `runsc` binary. +Note that `apt` installation of a specific point release is not supported. After installation, try out `runsc` by following the [Docker Quick Start](./quick_start/docker.md) or [OCI Quick Start](./quick_start/oci.md). - -[releases]: https://github.com/google/gvisor/releases diff --git a/g3doc/user_guide/networking.md b/g3doc/user_guide/networking.md index 62def5a90..95f675633 100644 --- a/g3doc/user_guide/networking.md +++ b/g3doc/user_guide/networking.md @@ -2,9 +2,9 @@ [TOC] -gVisor implements its own network stack called netstack. All aspects -of the network stack are handled inside the Sentry — including TCP connection -state, control messages, and packet assembly — keeping it isolated from the host +gVisor implements its own network stack called netstack. All aspects of the +network stack are handled inside the Sentry — including TCP connection state, +control messages, and packet assembly — keeping it isolated from the host network stack. Data link layer packets are written directly to the virtual device inside the network namespace setup by Docker or Kubernetes. diff --git a/g3doc/user_guide/quick_start/docker.md b/g3doc/user_guide/quick_start/docker.md index 6ad594ecc..ee842e453 100644 --- a/g3doc/user_guide/quick_start/docker.md +++ b/g3doc/user_guide/quick_start/docker.md @@ -22,18 +22,6 @@ named "runsc" by default. sudo runsc install ``` -You may also wish to install a runtime entry for debugging. The `runsc install` -command can accept options that will be passed to the runtime when it is invoked -by Docker. - -```bash -sudo runsc install --runtime runsc-debug -- \ - --debug \ - --debug-log=/tmp/runsc-debug.log \ - --strace \ - --log-packets -``` - You must restart the Docker daemon after installing the runtime. Typically this is done via `systemd`: @@ -85,6 +73,21 @@ $ docker run --runtime=runsc -it ubuntu dmesg Note that this is easily replicated by an attacker so applications should never use `dmesg` to verify the runtime in a security sensitive context. +## Options + +You may also wish to install a runtime entry with different options. The `runsc +install` command can accept flags that will be passed to the runtime when it is +invoked by Docker. For example, to install a runtime with debugging enabled, run +the following: + +```bash +sudo runsc install --runtime runsc-debug -- \ + --debug \ + --debug-log=/tmp/runsc-debug.log \ + --strace \ + --log-packets +``` + Next, look at the different options available for gVisor: [platform][platforms], [network][networking], [filesystem][filesystem]. diff --git a/g3doc/user_guide/quick_start/kubernetes.md b/g3doc/user_guide/quick_start/kubernetes.md index f875d8002..395cd4b71 100644 --- a/g3doc/user_guide/quick_start/kubernetes.md +++ b/g3doc/user_guide/quick_start/kubernetes.md @@ -6,17 +6,15 @@ with Kubernetes. ## Using Minikube gVisor can run sandboxed containers in a Kubernetes cluster with Minikube. After -the gVisor addon is enabled, pods with `io.kubernetes.cri.untrusted-workload` +the gVisor addon is enabled, pods with a `gvisor` [Runtime Class][runtimeclass] set to true will execute with `runsc`. Follow [these instructions][minikube] to enable gVisor addon. ## Using Containerd -You can also setup Kubernetes nodes to run pods in gvisor using the -[containerd][containerd] CRI runtime and the `gvisor-containerd-shim`. You can -use either the `io.kubernetes.cri.untrusted-workload` annotation or -[RuntimeClass][runtimeclass] to run Pods with `runsc`. You can find instructions -[here][gvisor-containerd-shim]. +You can also setup Kubernetes nodes to run pods in gVisor using +[containerd][containerd] and the gVisor containerd shim. You can find +instructions in the [Containerd Quick Start][gvisor-containerd]. ## Using GKE Sandbox @@ -31,6 +29,6 @@ WordPress site. You can view the full documentation [here][gke-sandbox-docs]. [gke]: https://cloud.google.com/kubernetes-engine/ [gke-sandbox]: https://cloud.google.com/kubernetes-engine/sandbox/ [gke-sandbox-docs]: https://cloud.google.com/kubernetes-engine/docs/how-to/sandbox-pods -[gvisor-containerd-shim]: https://github.com/google/gvisor-containerd-shim +[gvisor-containerd]: /docs/user_guide/containerd/quick_start/ [runtimeclass]: https://kubernetes.io/docs/concepts/containers/runtime-class/ [wordpress-quick]: /docs/tutorials/kubernetes/ diff --git a/g3doc/user_guide/quick_start/oci.md b/g3doc/user_guide/quick_start/oci.md index 877169145..e7768946b 100644 --- a/g3doc/user_guide/quick_start/oci.md +++ b/g3doc/user_guide/quick_start/oci.md @@ -15,8 +15,8 @@ mkdir bundle cd bundle ``` -Create a root file system for the container. We will use the Docker hello-world -image as the basis for our container. +Create a root file system for the container. We will use the Docker +`hello-world` image as the basis for our container. ```bash mkdir rootfs @@ -24,12 +24,10 @@ docker export $(docker create hello-world) | tar -xf - -C rootfs ``` Next, create an specification file called `config.json` that contains our -container specification. We will update the default command it runs to `/hello` -in the `hello-world` container. +container specification. We tell the container to run the `/hello` program. ```bash -runsc spec -sed -i 's;"sh";"/hello";' config.json +runsc spec -- /hello ``` Finally run the container. diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD index caae98623..f405349b3 100644 --- a/g3doc/user_guide/tutorials/BUILD +++ b/g3doc/user_guide/tutorials/BUILD @@ -11,16 +11,16 @@ doc( category = "User Guide", permalink = "/docs/tutorials/docker/", subcategory = "Tutorials", - weight = "21", + weight = "10", ) doc( - name = "cni", - src = "cni.md", + name = "docker_compose", + src = "docker-compose.md", category = "User Guide", - permalink = "/docs/tutorials/cni/", + permalink = "/docs/tutorials/docker-compose/", subcategory = "Tutorials", - weight = "22", + weight = "20", ) doc( @@ -33,5 +33,14 @@ doc( ], permalink = "/docs/tutorials/kubernetes/", subcategory = "Tutorials", - weight = "33", + weight = "30", +) + +doc( + name = "cni", + src = "cni.md", + category = "User Guide", + permalink = "/docs/tutorials/cni/", + subcategory = "Tutorials", + weight = "40", ) diff --git a/g3doc/user_guide/tutorials/cni.md b/g3doc/user_guide/tutorials/cni.md index ad6c9fa59..a3507c25b 100644 --- a/g3doc/user_guide/tutorials/cni.md +++ b/g3doc/user_guide/tutorials/cni.md @@ -47,7 +47,7 @@ sudo mkdir -p /etc/cni/net.d sudo sh -c 'cat > /etc/cni/net.d/10-bridge.conf << EOF { - "cniVersion": "0.4.0", + "cniVersion": "0.3.1", "name": "mynet", "type": "bridge", "bridge": "cni0", @@ -65,7 +65,7 @@ EOF' sudo sh -c 'cat > /etc/cni/net.d/99-loopback.conf << EOF { - "cniVersion": "0.4.0", + "cniVersion": "0.3.1", "name": "lo", "type": "loopback" } @@ -128,12 +128,14 @@ sudo mkdir -p rootfs/var/www/html sudo sh -c 'echo "Hello World!" > rootfs/var/www/html/index.html' ``` -Next create the `config.json` specifying the network namespace. `sudo -/usr/local/bin/runsc spec sudo sed -i 's;"sh";"python", "-m", "http.server";' -config.json sudo sed -i "s;\"cwd\": \"/\";\"cwd\": \"/var/www/html\";" -config.json sudo sed -i "s;\"type\": \"network\";\"type\": -\"network\",\n\t\t\t\t\"path\": \"/var/run/netns/${CNI_CONTAINERID}\";" -config.json` +Next create the `config.json` specifying the network namespace. + +``` +sudo /usr/local/bin/runsc spec \ + --cwd /var/www/html \ + --netns /var/run/netns/${CNI_CONTAINERID} \ + -- python -m http.server +``` ## Run the Container diff --git a/g3doc/user_guide/tutorials/docker-compose.md b/g3doc/user_guide/tutorials/docker-compose.md new file mode 100644 index 000000000..3284231f8 --- /dev/null +++ b/g3doc/user_guide/tutorials/docker-compose.md @@ -0,0 +1,100 @@ +# Wordpress with Docker Compose + +This page shows you how to deploy a sample [WordPress][wordpress] site using +[Docker Compose][docker-compose]. + +### Before you begin + +[Follow these instructions][docker-install] to install runsc with Docker. This +document assumes that Docker and Docker Compose are installed and the runtime +name chosen for gVisor is `runsc`. + +### Configuration + +We'll start by creating the `docker-compose.yaml` file to specify our services. +We will specify two services, a `wordpress` service for the Wordpress Apache +server, and a `db` service for MySQL. We will configure Wordpress to connect to +MySQL via the `db` service host name. + +> **Note:** Docker Compose uses it's own network by default and allows services +> to communicate using their service name. Docker Compose does this by setting +> up a DNS server at IP address 127.0.0.11 and configuring containers to use it +> via [resolv.conf][resolv.conf]. This IP is not addressable inside a gVisor +> sandbox so it's important that we set the DNS IP address to the alternative +> `8.8.8.8` and use a network that allows routing to it. See +> [Networking in Compose][compose-networking] for more details. + +> **Note:** The `runtime` field was removed from services in the 3.x version of +> the API in versions of docker-compose < 1.27.0. You will need to write your +> `docker-compose.yaml` file using the 2.x format or use docker-compose >= +> 1.27.0. See this [issue](https://github.com/docker/compose/issues/6239) for +> more details. + +```yaml +version: '2.3' + +services: + db: + image: mysql:5.7 + volumes: + - db_data:/var/lib/mysql + restart: always + environment: + MYSQL_ROOT_PASSWORD: somewordpress + MYSQL_DATABASE: wordpress + MYSQL_USER: wordpress + MYSQL_PASSWORD: wordpress + # All services must be on the same network to communicate. + network_mode: "bridge" + + wordpress: + depends_on: + - db + # When using the "bridge" network specify links. + links: + - db + image: wordpress:latest + ports: + - "8080:80" + restart: always + environment: + WORDPRESS_DB_HOST: db:3306 + WORDPRESS_DB_USER: wordpress + WORDPRESS_DB_PASSWORD: wordpress + WORDPRESS_DB_NAME: wordpress + # Specify the dns address if needed. + dns: + - 8.8.8.8 + # All services must be on the same network to communicate. + network_mode: "bridge" + # Specify the runtime used by Docker. Must be set up in + # /etc/docker/daemon.json. + runtime: "runsc" + +volumes: + db_data: {} +``` + +Once you have a `docker-compose.yaml` in the current directory you can start the +containers: + +```bash +docker-compose up +``` + +Once the containers have started you can access wordpress at +http://localhost:8080. + +Congrats! You now how a working wordpress site up and running using Docker +Compose. + +### What's next + +Learn how to deploy [WordPress with Kubernetes][wordpress-k8s]. + +[docker-compose]: https://docs.docker.com/compose/ +[docker-install]: ../quick_start/docker.md +[wordpress]: https://wordpress.com/ +[resolv.conf]: https://man7.org/linux/man-pages/man5/resolv.conf.5.html +[wordpress-k8s]: kubernetes.md +[compose-networking]: https://docs.docker.com/compose/networking/ diff --git a/g3doc/user_guide/tutorials/docker.md b/g3doc/user_guide/tutorials/docker.md index 705560038..9ca01da2a 100644 --- a/g3doc/user_guide/tutorials/docker.md +++ b/g3doc/user_guide/tutorials/docker.md @@ -60,9 +60,11 @@ Congratulations! You have just deployed a WordPress site using Docker. ### What's next -[Learn how to deploy WordPress with Kubernetes][wordpress-k8s]. +Learn how to deploy WordPress with [Kubernetes][wordpress-k8s] or +[Docker Compose][wordpress-compose]. [docker]: https://www.docker.com/ -[docker-install]: /docs/user_guide/quick_start/docker/ +[docker-install]: ../quick_start/docker.md [wordpress]: https://wordpress.com/ -[wordpress-k8s]: /docs/tutorials/kubernetes/ +[wordpress-k8s]: kubernetes.md +[wordpress-compose]: docker-compose.md diff --git a/g3doc/user_guide/tutorials/kubernetes.md b/g3doc/user_guide/tutorials/kubernetes.md index d2a94b1b7..1ec6e71e9 100644 --- a/g3doc/user_guide/tutorials/kubernetes.md +++ b/g3doc/user_guide/tutorials/kubernetes.md @@ -23,12 +23,12 @@ gcloud beta container node-pools create sandbox-pool --cluster=${CLUSTER_NAME} - If you prefer to use the console, select your cluster and select the **ADD NODE POOL** button: - + Then select the **Image type** with **Containerd** and select **Enable sandbox with gVisor** option. Select other options as you like: - + ### Check that gVisor is enabled @@ -57,47 +57,149 @@ curl -LO https://k8s.io/examples/application/wordpress/mysql-deployment.yaml Add a **spec.template.spec.runtimeClassName** set to **gvisor** to both files, as shown below: -**wordpress-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata: -name: wordpress labels: app: wordpress spec: ports: - port: 80 selector: app: -wordpress tier: frontend - -## type: LoadBalancer - -apiVersion: v1 kind: PersistentVolumeClaim metadata: name: wp-pv-claim labels: -app: wordpress spec: accessModes: - ReadWriteOnce resources: requests: - -## storage: 20Gi - -apiVersion: apps/v1 kind: Deployment metadata: name: wordpress labels: app: -wordpress spec: selector: matchLabels: app: wordpress tier: frontend strategy: -type: Recreate template: metadata: labels: app: wordpress tier: frontend spec: -runtimeClassName: gvisor # ADD THIS LINE containers: - image: -wordpress:4.8-apache name: wordpress env: - name: WORDPRESS_DB_HOST value: -wordpress-mysql - name: WORDPRESS_DB_PASSWORD valueFrom: secretKeyRef: name: -mysql-pass key: password ports: - containerPort: 80 name: wordpress -volumeMounts: - name: wordpress-persistent-storage mountPath: /var/www/html -volumes: - name: wordpress-persistent-storage persistentVolumeClaim: claimName: -wp-pv-claim ``` - -**mysql-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata: name: -wordpress-mysql labels: app: wordpress spec: ports: - port: 3306 selector: app: -wordpress tier: mysql - -## clusterIP: None - -apiVersion: v1 kind: PersistentVolumeClaim metadata: name: mysql-pv-claim -labels: app: wordpress spec: accessModes: - ReadWriteOnce resources: requests: - -## storage: 20Gi +**wordpress-deployment.yaml:** + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: wordpress + labels: + app: wordpress +spec: + ports: + - port: 80 + selector: + app: wordpress + tier: frontend + type: LoadBalancer +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: wp-pv-claim + labels: + app: wordpress +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 20Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: wordpress + labels: + app: wordpress +spec: + selector: + matchLabels: + app: wordpress + tier: frontend + strategy: + type: Recreate + template: + metadata: + labels: + app: wordpress + tier: frontend + spec: + runtimeClassName: gvisor # ADD THIS LINE + containers: + - image: wordpress:4.8-apache + name: wordpress + env: + - name: WORDPRESS_DB_HOST + value: wordpress-mysql + - name: WORDPRESS_DB_PASSWORD + valueFrom: + secretKeyRef: + name: mysql-pass + key: password + ports: + - containerPort: 80 + name: wordpress + volumeMounts: + - name: wordpress-persistent-storage + mountPath: /var/www/html + volumes: + - name: wordpress-persistent-storage + persistentVolumeClaim: + claimName: wp-pv-claim +``` -apiVersion: apps/v1 kind: Deployment metadata: name: wordpress-mysql labels: -app: wordpress spec: selector: matchLabels: app: wordpress tier: mysql strategy: -type: Recreate template: metadata: labels: app: wordpress tier: mysql spec: -runtimeClassName: gvisor # ADD THIS LINE containers: - image: mysql:5.6 name: -mysql env: - name: MYSQL_ROOT_PASSWORD valueFrom: secretKeyRef: name: mysql-pass -key: password ports: - containerPort: 3306 name: mysql volumeMounts: - name: -mysql-persistent-storage mountPath: /var/lib/mysql volumes: - name: -mysql-persistent-storage persistentVolumeClaim: claimName: mysql-pv-claim ``` +**mysql-deployment.yaml:** + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: wordpress-mysql + labels: + app: wordpress +spec: + ports: + - port: 3306 + selector: + app: wordpress + tier: mysql + clusterIP: None +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: mysql-pv-claim + labels: + app: wordpress +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 20Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: wordpress-mysql + labels: + app: wordpress +spec: + selector: + matchLabels: + app: wordpress + tier: mysql + strategy: + type: Recreate + template: + metadata: + labels: + app: wordpress + tier: mysql + spec: + runtimeClassName: gvisor # ADD THIS LINE + containers: + - image: mysql:5.6 + name: mysql + env: + - name: MYSQL_ROOT_PASSWORD + valueFrom: + secretKeyRef: + name: mysql-pass + key: password + ports: + - containerPort: 3306 + name: mysql + volumeMounts: + - name: mysql-persistent-storage + mountPath: /var/lib/mysql + volumes: + - name: mysql-persistent-storage + persistentVolumeClaim: + claimName: mysql-pv-claim +``` Note that apart from `runtimeClassName: gvisor`, nothing else about the Deployment has is changed. @@ -2,19 +2,51 @@ module gvisor.dev/gvisor go 1.14 +replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0 + require ( - github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 - github.com/golang/protobuf v1.3.1 - github.com/google/btree v1.0.0 - github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 - github.com/kr/pretty v0.2.0 // indirect - github.com/kr/pty v1.1.1 - github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 - github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 - github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e - github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // indirect - golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect + cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 // indirect + github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 // indirect + github.com/Microsoft/hcsshim v0.8.6 // indirect + github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect + github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect + github.com/containerd/containerd v1.3.4 // indirect + github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect + github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect + github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect + github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 // indirect + github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 // indirect + github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect + github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible // indirect + github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect + github.com/docker/go-connections v0.3.0 // indirect + github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect + github.com/docker/go-units v0.4.0 // indirect + github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b // indirect + github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect + github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect + github.com/gogo/googleapis v1.4.0 // indirect + github.com/golang/protobuf v1.4.2 // indirect + github.com/google/go-cmp v0.5.0 // indirect + github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect + github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect + github.com/hashicorp/go-multierror v1.0.0 // indirect + github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect + github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.0.1 // indirect + github.com/opencontainers/runc v0.1.1 // indirect + github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect + github.com/pborman/uuid v1.2.0 // indirect + github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect + github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5 // indirect + github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect + github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe // indirect + go.uber.org/atomic v1.6.0 // indirect + go.uber.org/multierr v1.2.0 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a // indirect + google.golang.org/grpc v1.29.0 // indirect + gopkg.in/yaml.v2 v2.2.8 // indirect + gotest.tools v2.2.0+incompatible // indirect ) @@ -1,32 +1,387 @@ -github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8= -github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= +bazil.org/fuse v0.0.0-20160811212531-371fbbdaa898/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8= +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= +cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= +cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= +cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo= +cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= +cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Microsoft/go-winio v0.4.14 h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU= +github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= +github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA= +github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= +github.com/Microsoft/hcsshim v0.8.6/go.mod h1:Op3hHsoHPAvb6lceZHDtd9OkTew38wNoXnJs8iY7rUg= +github.com/Microsoft/hcsshim v0.8.7/go.mod h1:OHd7sQqRFrYd3RmSgbgji+ctCwkbq2wbEYNSzOYtcBQ= +github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= +github.com/Microsoft/hcsshim v0.8.9 h1:VrfodqvztU8YSOvygU+DN1BGaSGxmrNfqOv5oOuX2Bk= +github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= +github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4= +github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY= +github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI= +github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f h1:tSNMc+rJDfmYntojat8lljbt1mgKNpTxUZJsSzJ9Y1s= +github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f/go.mod h1:OApqhQ4XNSNC13gXIwDjhOQxjWa/NxkwZXJ1EvqT0ko= +github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw= +github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc= +github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE= +github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= +github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI= +github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= +github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= +github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k= +github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe/go.mod h1:cECdGN1O8G9bgKTlLhuPJimka6Xb/Gg7vYzCTNVxhvo= +github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI= +github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw= +github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0= +github.com/containerd/go-runc v0.0.0-20180907222934-5a6d9f37cfa3/go.mod h1:IV7qH3hrUgRmyYrtgEeGWJfWbgcHL9CSRruz2Vqcph0= +github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds= +github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328/go.mod h1:PpyHrqVs8FTi9vpyHwPwiNEGaACDxT/N/pLcvMSRA9g= +github.com/containerd/ttrpc v0.0.0-20190828154514-0e0f228740de/go.mod h1:PvCDdDGpgqzQIzDW1TphrGLssLDZp2GuS+X5DkEJB8o= +github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc= +github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15/go.mod h1:UAxOpgT9ziI0gJrmKvgcZivgxOp8iFPSk8httJEt98Y= +github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= +github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o= +github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= +github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w= +github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.3.0 h1:3lOnM9cSzgGwx8VfK/NGOW5fLQ0GjIlCkaktF+n1M6o= +github.com/docker/go-connections v0.3.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ+oDZB4KHQFypsfjYlq/C4rfL7D3g8= +github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= +github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U= +github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= +github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= +github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= -github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= +github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI= +github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI= -github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= +github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM= +github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= +github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M= +github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI= -github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c= +github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= +github.com/opencontainers/go-digest v1.0.0-rc1/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI= +github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= +github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= +github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y= +github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= +github.com/opencontainers/runtime-spec v0.1.2-0.20190507144316-5b71a03e2700/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8= +github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= -github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q= -github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= -github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk= -github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= -golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w= +github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= +github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so= +github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= +go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= +go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= +go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/exp v0.0.0-20191227195350-da58074b4299 h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg= +golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190514135907-3a4b5fb9f71f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o= +golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE= +golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= +google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM= +google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo= +google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/images/Makefile b/images/Makefile index 1485607bd..d4b6524ba 100644 --- a/images/Makefile +++ b/images/Makefile @@ -34,8 +34,15 @@ list-all-images: @for image in $(ALL_IMAGES); do echo $${image}; done .PHONY: list-build-images +# Handy wrapper to allow load-all-images, push-all-images, etc. %-all-images: @$(MAKE) $(patsubst %,$*-%,$(ALL_IMAGES)) +load-all-images: + @$(MAKE) $(patsubst %,load-%,$(ALL_IMAGES)) + +# Handy wrapper to load specified "groups", e.g. load-basic-images, etc. +load-%-images: + @$(MAKE) $(patsubst %,load-%,$(subst /,_,$(subst ./,,$(shell find ./$* -name Dockerfile -exec dirname {} \;)))) # tag is a function that returns the tag name, given an image. # @@ -52,9 +59,9 @@ local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1)) # we need to explicitly repull the base layer in order to ensure that the # architecture is correct. Note that we use the term "rebuild" here to avoid # conflicting with the bazel "build" terminology, which is used elsewhere. +rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile | cut -d' ' -f2) rebuild-%: register-cross - FROM=$(shell grep FROM $(call path,$*)/Dockerfile | cut -d' ' -f2-) && \ - docker pull $(DOCKER_PLATFORM_ARGS) $$FROM + $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \ T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \ docker build $(DOCKER_PLATFORM_ARGS) -t $(call remote_image,$*) $$T && \ rm -rf $$T @@ -66,10 +73,10 @@ pull-%: docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*) # load will either pull the "remote" or build it locally. This is the preferred -# entrypoint, as it should never file. The local tag should always be set after +# entrypoint, as it should never fail. The local tag should always be set after # this returns (either by the pull or the build). load-%: - docker inspect $(call remote_image,$*) >/dev/null 2>&1 || $(MAKE) pull-$* || $(MAKE) rebuild-$* + $(MAKE) pull-$* || $(MAKE) rebuild-$* docker tag $(call remote_image,$*) $(call local_image,$*) # push pushes the remote image, after either pulling (to validate that the tag diff --git a/images/README.md b/images/README.md index d2efb5db4..9880946a6 100644 --- a/images/README.md +++ b/images/README.md @@ -7,7 +7,7 @@ Note that all these images must be pushed to the testing project hosted on continuous integration. This will speed up loading as images will not need to be built from scratch for each test run. -Image tooling is accessible via `make`, specifically via `tools/images.mk`. +Image tooling is accessible via `make`, specifically via `images/Makefile`. ## Why make? @@ -59,3 +59,12 @@ project. The continuous integration system can either take fine-grained dependencies on individual `push` targets, or ensure all images are up-to-date with a single `push-all-images` invocation. + +## Multi-Arch images + +By default, the image is built for host architecture. Cross-building can be +achieved by specifying `ARCH` variable to make. For example: + +``` +$ make ARCH=aarch64 rebuild-default +``` diff --git a/images/basic/hostoverlaytest/Dockerfile b/images/basic/hostoverlaytest/Dockerfile new file mode 100644 index 000000000..6cef1a542 --- /dev/null +++ b/images/basic/hostoverlaytest/Dockerfile @@ -0,0 +1,8 @@ +FROM ubuntu:bionic + +WORKDIR /root +COPY . . + +RUN apt-get update && apt-get install -y gcc +RUN gcc -O2 -o test_copy_up test_copy_up.c +RUN gcc -O2 -o test_rewinddir test_rewinddir.c diff --git a/images/hostoverlaytest/testfile.txt b/images/basic/hostoverlaytest/copy_up_testfile.txt index e4188c841..e4188c841 100644 --- a/images/hostoverlaytest/testfile.txt +++ b/images/basic/hostoverlaytest/copy_up_testfile.txt diff --git a/images/hostoverlaytest/test.c b/images/basic/hostoverlaytest/test_copy_up.c index 088f90746..010b261dc 100644 --- a/images/hostoverlaytest/test.c +++ b/images/basic/hostoverlaytest/test_copy_up.c @@ -6,7 +6,7 @@ #include <unistd.h> int main(int argc, char** argv) { - const char kTestFilePath[] = "testfile.txt"; + const char kTestFilePath[] = "copy_up_testfile.txt"; const char kOldFileData[] = "old data\n"; const char kNewFileData[] = "new data\n"; const size_t kPageSize = sysconf(_SC_PAGE_SIZE); diff --git a/images/basic/hostoverlaytest/test_rewinddir.c b/images/basic/hostoverlaytest/test_rewinddir.c new file mode 100644 index 000000000..f1a4085e1 --- /dev/null +++ b/images/basic/hostoverlaytest/test_rewinddir.c @@ -0,0 +1,78 @@ +#include <dirent.h> +#include <err.h> +#include <errno.h> +#include <stdlib.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> + +int main(int argc, char** argv) { + const char kDirPath[] = "rewinddir_test_dir"; + const char kFileBasename[] = "rewinddir_test_file"; + + // Create the test directory. + if (mkdir(kDirPath, 0755) < 0) { + err(1, "mkdir(%s)", kDirPath); + } + + // The test directory should initially be empty. + DIR* dir = opendir(kDirPath); + if (!dir) { + err(1, "opendir(%s)", kDirPath); + } + int failed = 0; + while (1) { + errno = 0; + struct dirent* d = readdir(dir); + if (!d) { + if (errno != 0) { + err(1, "readdir"); + } + break; + } + if (strcmp(d->d_name, ".") != 0 && strcmp(d->d_name, "..") != 0) { + warnx("unexpected file %s in new directory", d->d_name); + failed = 1; + } + } + + // Create a file in the test directory. + char* file_path = malloc(strlen(kDirPath) + 1 + strlen(kFileBasename)); + if (!file_path) { + errx(1, "malloc"); + } + strcpy(file_path, kDirPath); + file_path[strlen(kDirPath)] = '/'; + strcpy(file_path + strlen(kDirPath) + 1, kFileBasename); + if (mknod(file_path, 0644, 0) < 0) { + err(1, "mknod(%s)", file_path); + } + + // After rewinddir(), re-reading the directory stream should yield the new + // file. + rewinddir(dir); + size_t found_file = 0; + while (1) { + errno = 0; + struct dirent* d = readdir(dir); + if (!d) { + if (errno != 0) { + err(1, "readdir"); + } + break; + } + if (strcmp(d->d_name, kFileBasename) == 0) { + found_file++; + } else if (strcmp(d->d_name, ".") != 0 && strcmp(d->d_name, "..") != 0) { + warnx("unexpected file %s in new directory", d->d_name); + failed = 1; + } + } + if (found_file != 1) { + warnx("readdir returned file %s %zu times, wanted 1", kFileBasename, + found_file); + failed = 1; + } + + return failed; +} diff --git a/images/hostoverlaytest/Dockerfile b/images/basic/linktest/Dockerfile index d83439e9c..baebc9b76 100644 --- a/images/hostoverlaytest/Dockerfile +++ b/images/basic/linktest/Dockerfile @@ -4,4 +4,4 @@ WORKDIR /root COPY . . RUN apt-get update && apt-get install -y gcc -RUN gcc -O2 -o test test.c +RUN gcc -O2 -o link_test link_test.c diff --git a/images/basic/linktest/link_test.c b/images/basic/linktest/link_test.c new file mode 100644 index 000000000..45ab00abe --- /dev/null +++ b/images/basic/linktest/link_test.c @@ -0,0 +1,93 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <err.h> +#include <fcntl.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it +// cannot use tricks like userns as root. For this reason, run a basic link test +// to ensure some coverage. +int main(int argc, char** argv) { + const char kOldPath[] = "old.txt"; + int fd = open(kOldPath, O_RDWR | O_CREAT | O_TRUNC, 0600); + if (fd < 0) { + errx(1, "open(%s) failed", kOldPath); + } + const char kData[] = "some random content"; + if (write(fd, kData, sizeof(kData)) < 0) { + err(1, "write failed"); + } + close(fd); + + struct stat old_stat; + if (stat(kOldPath, &old_stat)) { + errx(1, "stat(%s) failed", kOldPath); + } + + const char kNewPath[] = "new.txt"; + if (link(kOldPath, kNewPath)) { + errx(1, "link(%s, %s) failed", kOldPath, kNewPath); + } + + struct stat new_stat; + if (stat(kNewPath, &new_stat)) { + errx(1, "stat(%s) failed", kNewPath); + } + + // Check that files are the same. + if (old_stat.st_dev != new_stat.st_dev) { + errx(1, "files st_dev is different, want: %lu, got: %lu", old_stat.st_dev, + new_stat.st_dev); + } + if (old_stat.st_ino != new_stat.st_ino) { + errx(1, "files st_ino is different, want: %lu, got: %lu", old_stat.st_ino, + new_stat.st_ino); + } + + // Check that link count is correct. + if (new_stat.st_nlink != old_stat.st_nlink + 1) { + errx(1, "wrong nlink, want: %lu, got: %lu", old_stat.st_nlink + 1, + new_stat.st_nlink); + } + + // Check taht contents are the same. + fd = open(kNewPath, O_RDONLY); + if (fd < 0) { + errx(1, "open(%s) failed", kNewPath); + } + char buf[sizeof(kData)] = {}; + if (read(fd, buf, sizeof(buf)) < 0) { + err(1, "read failed"); + } + close(fd); + + if (strcmp(buf, kData) != 0) { + errx(1, "file content mismatch: %s", buf); + } + + // Cleanup. + if (unlink(kNewPath)) { + errx(1, "unlink(%s) failed", kNewPath); + } + if (unlink(kOldPath)) { + errx(1, "unlink(%s) failed", kOldPath); + } + + // Success! + return 0; +} diff --git a/images/tmpfile/Dockerfile b/images/basic/tmpfile/Dockerfile index e3816c8cb..e3816c8cb 100644 --- a/images/tmpfile/Dockerfile +++ b/images/basic/tmpfile/Dockerfile diff --git a/images/benchmarks/ab/Dockerfile b/images/benchmarks/ab/Dockerfile new file mode 100644 index 000000000..10544639b --- /dev/null +++ b/images/benchmarks/ab/Dockerfile @@ -0,0 +1,7 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2-utils \ + && rm -rf /var/lib/apt/lists/* diff --git a/benchmarks/workloads/absl/Dockerfile b/images/benchmarks/absl/Dockerfile index f29cfa156..b0dd97695 100644 --- a/benchmarks/workloads/absl/Dockerfile +++ b/images/benchmarks/absl/Dockerfile @@ -19,7 +19,3 @@ RUN ./bazel-0.27.0-installer-linux-x86_64.sh RUN mkdir abseil-cpp && cd abseil-cpp \ && git init && git remote add origin https://github.com/abseil/abseil-cpp.git \ && git fetch --depth 1 origin 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f && git checkout FETCH_HEAD -WORKDIR abseil-cpp -RUN bazel clean -ENV path "absl/base/..." -CMD bazel build ${path} 2>&1 diff --git a/benchmarks/workloads/true/Dockerfile b/images/benchmarks/alpine/Dockerfile index 2e97c921e..b09b037ca 100644 --- a/benchmarks/workloads/true/Dockerfile +++ b/images/benchmarks/alpine/Dockerfile @@ -1,3 +1 @@ FROM alpine:latest - -CMD ["true"] diff --git a/benchmarks/workloads/ffmpeg/Dockerfile b/images/benchmarks/ffmpeg/Dockerfile index f2f530d7c..7108df64f 100644 --- a/benchmarks/workloads/ffmpeg/Dockerfile +++ b/images/benchmarks/ffmpeg/Dockerfile @@ -7,4 +7,3 @@ RUN set -x \ && rm -rf /var/lib/apt/lists/* WORKDIR /media ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4 -CMD ["ffmpeg", "-i", "video.mp4", "-c:v", "libx264", "-preset", "veryslow", "output.mp4"] diff --git a/images/benchmarks/fio/Dockerfile b/images/benchmarks/fio/Dockerfile new file mode 100644 index 000000000..9531df7fa --- /dev/null +++ b/images/benchmarks/fio/Dockerfile @@ -0,0 +1,7 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + fio \ + && rm -rf /var/lib/apt/lists/* diff --git a/images/benchmarks/hey/Dockerfile b/images/benchmarks/hey/Dockerfile new file mode 100644 index 000000000..f586978b6 --- /dev/null +++ b/images/benchmarks/hey/Dockerfile @@ -0,0 +1,12 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + && rm -rf /var/lib/apt/lists/* + +RUN wget https://storage.googleapis.com/hey-release/hey_linux_amd64 \ + && chmod 777 hey_linux_amd64 \ + && cp hey_linux_amd64 /bin/hey \ + && rm hey_linux_amd64 diff --git a/benchmarks/workloads/httpd/Dockerfile b/images/benchmarks/httpd/Dockerfile index 52a550678..e95538a40 100644 --- a/benchmarks/workloads/httpd/Dockerfile +++ b/images/benchmarks/httpd/Dockerfile @@ -8,20 +8,10 @@ RUN set -x \ # Generate a bunch of relevant files. RUN mkdir -p /local && \ - for size in 1 10 100 1000 1024 10240; do \ + for size in 1 10 100 1024 10240; do \ dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ done # Rewrite DocumentRoot to point to /tmp/html instead of the default path. RUN sed -i 's/DocumentRoot.*\/var\/www\/html$/DocumentRoot \/tmp\/html/' /etc/apache2/sites-enabled/000-default.conf COPY ./apache2-tmpdir.conf /etc/apache2/sites-enabled/apache2-tmpdir.conf - -# Standard settings. -ENV APACHE_RUN_DIR /tmp -ENV APACHE_RUN_USER nobody -ENV APACHE_RUN_GROUP nogroup -ENV APACHE_LOG_DIR /tmp -ENV APACHE_PID_FILE /tmp/apache.pid - -# Copy on start-up; serve everything from /tmp (including the configuration). -CMD ["sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && apache2 -X"] diff --git a/benchmarks/workloads/httpd/apache2-tmpdir.conf b/images/benchmarks/httpd/apache2-tmpdir.conf index e33f8d9bb..e33f8d9bb 100644 --- a/benchmarks/workloads/httpd/apache2-tmpdir.conf +++ b/images/benchmarks/httpd/apache2-tmpdir.conf diff --git a/images/benchmarks/iperf/Dockerfile b/images/benchmarks/iperf/Dockerfile new file mode 100644 index 000000000..4cbfd0d70 --- /dev/null +++ b/images/benchmarks/iperf/Dockerfile @@ -0,0 +1,8 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + iperf \ + && rm -rf /var/lib/apt/lists/* + diff --git a/images/benchmarks/nginx/Dockerfile b/images/benchmarks/nginx/Dockerfile new file mode 100644 index 000000000..c8e3330d0 --- /dev/null +++ b/images/benchmarks/nginx/Dockerfile @@ -0,0 +1,12 @@ +FROM nginx:1.15.10 + +# Generate a bunch of relevant files. +RUN mkdir -p /local && \ + for size in 1 10 100 1024 10240; do \ + dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ + done + +RUN touch /local/index.html + +COPY ./nginx.conf /etc/nginx/nginx.conf +COPY ./nginx_gofer.conf /etc/nginx/nginx_gofer.conf diff --git a/images/benchmarks/nginx/nginx.conf b/images/benchmarks/nginx/nginx.conf new file mode 100644 index 000000000..2c43c0cda --- /dev/null +++ b/images/benchmarks/nginx/nginx.conf @@ -0,0 +1,19 @@ +user nginx; +worker_processes 1; +daemon off; + +error_log /var/log/nginx/error.log warn; +pid /var/run/nginx.pid; + +events { + worker_connections 1024; +} + + +http { + server { + location / { + root /tmp/html; + } + } +} diff --git a/images/benchmarks/nginx/nginx_gofer.conf b/images/benchmarks/nginx/nginx_gofer.conf new file mode 100644 index 000000000..dbba2a575 --- /dev/null +++ b/images/benchmarks/nginx/nginx_gofer.conf @@ -0,0 +1,19 @@ +user nginx; +worker_processes 1; +daemon off; + +error_log /var/log/nginx/error.log warn; +pid /var/run/nginx.pid; + +events { + worker_connections 1024; +} + + +http { + server { + location / { + root /local; + } + } +} diff --git a/images/benchmarks/node/Dockerfile b/images/benchmarks/node/Dockerfile new file mode 100644 index 000000000..bf45650a0 --- /dev/null +++ b/images/benchmarks/node/Dockerfile @@ -0,0 +1 @@ +FROM node:onbuild diff --git a/benchmarks/workloads/node_template/index.hbs b/images/benchmarks/node/index.hbs index 03feceb75..03feceb75 100644 --- a/benchmarks/workloads/node_template/index.hbs +++ b/images/benchmarks/node/index.hbs diff --git a/benchmarks/workloads/node_template/index.js b/images/benchmarks/node/index.js index 04a27f356..831015d18 100644 --- a/benchmarks/workloads/node_template/index.js +++ b/images/benchmarks/node/index.js @@ -19,7 +19,6 @@ app.get('/', (req, res) => { tmp.push(reply.toString()); }); } - res.render('index', {text: tmp}); }); diff --git a/benchmarks/workloads/node_template/package-lock.json b/images/benchmarks/node/package-lock.json index 580e68aa5..580e68aa5 100644 --- a/benchmarks/workloads/node_template/package-lock.json +++ b/images/benchmarks/node/package-lock.json diff --git a/benchmarks/workloads/node_template/package.json b/images/benchmarks/node/package.json index 7dcadd523..7dcadd523 100644 --- a/benchmarks/workloads/node_template/package.json +++ b/images/benchmarks/node/package.json diff --git a/benchmarks/workloads/redis/Dockerfile b/images/benchmarks/redis/Dockerfile index 0f17249af..0f17249af 100644 --- a/benchmarks/workloads/redis/Dockerfile +++ b/images/benchmarks/redis/Dockerfile diff --git a/benchmarks/workloads/ruby_template/Dockerfile b/images/benchmarks/ruby/Dockerfile index a06d68bf4..13c4f6eed 100755 --- a/benchmarks/workloads/ruby_template/Dockerfile +++ b/images/benchmarks/ruby/Dockerfile @@ -1,5 +1,4 @@ # example based on https://github.com/errm/fib - FROM alpine:3.9 as build COPY Gemfile Gemfile.lock ./ @@ -23,16 +22,6 @@ RUN apk add --no-cache ruby ruby-json ruby-etc redis apache2-utils \ ).generate_bin \ end" -WORKDIR /app COPY . /app/. -ENV PORT=9292 \ - WEB_CONCURRENCY=20 \ - WEB_MAX_THREADS=20 \ - RACK_ENV=production - -ENV host localhost -EXPOSE $PORT -USER nobody STOPSIGNAL SIGINT -CMD ["sh", "-c", "/usr/bin/puma", "${host}"] diff --git a/benchmarks/workloads/ruby_template/Gemfile b/images/benchmarks/ruby/Gemfile index ac521b32c..ac521b32c 100755 --- a/benchmarks/workloads/ruby_template/Gemfile +++ b/images/benchmarks/ruby/Gemfile diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/images/benchmarks/ruby/Gemfile.lock index eeb3c7bbe..041778e02 100644 --- a/benchmarks/workloads/ruby_template/Gemfile.lock +++ b/images/benchmarks/ruby/Gemfile.lock @@ -2,7 +2,7 @@ GEM remote: https://rubygems.org/ specs: mustermann (1.0.3) - puma (3.12.6) + puma (3.4.0) rack (2.0.6) rack-protection (2.0.5) rack diff --git a/benchmarks/workloads/ruby_template/config.ru b/images/benchmarks/ruby/config.ru index b2d135cc0..b2d135cc0 100755 --- a/benchmarks/workloads/ruby_template/config.ru +++ b/images/benchmarks/ruby/config.ru diff --git a/benchmarks/workloads/ruby_template/index.erb b/images/benchmarks/ruby/index.erb index 7f7300e80..7f7300e80 100755 --- a/benchmarks/workloads/ruby_template/index.erb +++ b/images/benchmarks/ruby/index.erb diff --git a/benchmarks/workloads/ruby_template/main.rb b/images/benchmarks/ruby/main.rb index 35c239377..b998f004e 100755 --- a/benchmarks/workloads/ruby_template/main.rb +++ b/images/benchmarks/ruby/main.rb @@ -2,7 +2,7 @@ require "sinatra" require "securerandom" require "redis" -redis_host = ENV["host"] +redis_host = ENV["HOST"] $redis = Redis.new(host: redis_host) def generateText @@ -24,4 +24,4 @@ get "/" do texts.push($redis.get(rand(0..99))) end template.result_with_hash(text: texts) -end
\ No newline at end of file +end diff --git a/images/benchmarks/runsc/Dockerfile b/images/benchmarks/runsc/Dockerfile new file mode 100644 index 000000000..6c3aafa57 --- /dev/null +++ b/images/benchmarks/runsc/Dockerfile @@ -0,0 +1,24 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + git \ + pkg-config \ + zip \ + g++ \ + zlib1g-dev \ + unzip \ + python-minimal \ + python3 \ + python3-pip \ + && rm -rf /var/lib/apt/lists/* +RUN wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-installer-linux-x86_64.sh +RUN chmod +x bazel-3.4.1-installer-linux-x86_64.sh +RUN ./bazel-3.4.1-installer-linux-x86_64.sh + +# Download release-20200601.0 +RUN mkdir gvisor && cd gvisor \ + && git init && git remote add origin https://github.com/google/gvisor.git \ + && git fetch --depth 1 origin a9b47390c821942d60784e308f681f213645049c && git checkout FETCH_HEAD diff --git a/images/benchmarks/sysbench/Dockerfile b/images/benchmarks/sysbench/Dockerfile new file mode 100644 index 000000000..55e865f43 --- /dev/null +++ b/images/benchmarks/sysbench/Dockerfile @@ -0,0 +1,7 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + sysbench \ + && rm -rf /var/lib/apt/lists/* diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/images/benchmarks/tensorflow/Dockerfile index b5763e8ae..7564a4ee5 100644 --- a/benchmarks/workloads/tensorflow/Dockerfile +++ b/images/benchmarks/tensorflow/Dockerfile @@ -5,10 +5,3 @@ RUN apt-get update \ RUN git clone --depth 1 https://github.com/aymericdamien/TensorFlow-Examples.git RUN python -m pip install -U pip setuptools RUN python -m pip install matplotlib - -WORKDIR /TensorFlow-Examples/examples - -ENV PYTHONPATH="$PYTHONPATH:/TensorFlow-Examples/examples" - -ENV workload "3_NeuralNetworks/convolutional_network.py" -CMD python ${workload} diff --git a/images/benchmarks/util/Dockerfile b/images/benchmarks/util/Dockerfile new file mode 100644 index 000000000..f2799b3e6 --- /dev/null +++ b/images/benchmarks/util/Dockerfile @@ -0,0 +1,3 @@ +FROM ubuntu:bionic + +RUN apt-get update && apt-get install -y wget diff --git a/images/default/Dockerfile b/images/default/Dockerfile index 397082b02..d058b83cb 100644 --- a/images/default/Dockerfile +++ b/images/default/Dockerfile @@ -1,8 +1,8 @@ FROM fedora:31 # Install bazel. RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel -RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch -RUN pip install pycparser +RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils +RUN pip install --no-cache-dir pycparser RUN dnf install -y bazel3 # Install gcloud. RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \ diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile index 4860dd750..ae19f3bfc 100644 --- a/images/jekyll/Dockerfile +++ b/images/jekyll/Dockerfile @@ -1,5 +1,6 @@ FROM jekyll/jekyll:4.0.0 USER root + RUN gem install \ html-proofer:3.10.2 \ nokogiri:1.10.1 \ @@ -10,4 +11,9 @@ RUN gem install \ jekyll-relative-links:0.6.1 \ jekyll-feed:0.13.0 \ jekyll-sitemap:1.4.0 -CMD ["/usr/gem/gems/jekyll-4.0.0/exe/jekyll", "build", "-t", "-s", "/input", "-d", "/output"] + +# checks.rb is used with html-proofer for presubmit checks. +COPY checks.rb /checks.rb + +COPY build.sh /build.sh +CMD ["/build.sh"] diff --git a/scripts/make_tests.sh b/images/jekyll/build.sh index dbf1bba77..010972ea6 100755 --- a/scripts/make_tests.sh +++ b/images/jekyll/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright 2019 The gVisor Authors. +# Copyright 2020 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -source $(dirname $0)/common.sh +set -euxo pipefail -make runsc -make bazel-shutdown +# Generate the syntax highlighting css file. +/usr/gem/bin/rougify style github >/input/_sass/syntax.css +# Build website including pages irrespective of date. +/usr/gem/bin/jekyll build --future -t -s /input -d /output diff --git a/images/jekyll/checks.rb b/images/jekyll/checks.rb new file mode 100644 index 000000000..fc7e6b5a8 --- /dev/null +++ b/images/jekyll/checks.rb @@ -0,0 +1,36 @@ +#!/usr/local/bin/ruby +# +# HTMLProofer checks for the gVisor website. +# +require 'html-proofer' + +# NoOpenerCheck checks to make sure links with target=_blank include the +# rel=noopener attribute. +class NoOpenerCheck < ::HTMLProofer::Check + def run + @html.css('a').each do |node| + link = create_element(node) + line = node.line + + rel = link.respond_to?(:rel) ? link.rel.split(' ') : [] + + if link.respond_to?(:target) && link.target == "_blank" && !rel.include?("noopener") + return add_issue("You should set rel=noopener for links with target=_blank", line: line) + end + end + end +end + +def main() + options = { + :check_html => true, + :check_favicon => true, + :disable_external => true, + } + + HTMLProofer.check_directories(ARGV, options).run +end + +if __FILE__ == $0 + main +end diff --git a/images/packetdrill/Dockerfile b/images/packetdrill/Dockerfile index 01296dbaf..b4cd73006 100644 --- a/images/packetdrill/Dockerfile +++ b/images/packetdrill/Dockerfile @@ -1,8 +1,8 @@ FROM ubuntu:bionic RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \ netcat tcpdump jq tar bison flex make +# Pick up updated git. RUN hash -r RUN git clone --depth 1 --branch packetdrill-v2.0 \ https://github.com/google/packetdrill.git RUN cd packetdrill/gtests/net/packetdrill && ./configure && make -CMD /bin/bash diff --git a/images/packetimpact/Dockerfile b/images/packetimpact/Dockerfile index 87aa99ef2..906d5cdd6 100644 --- a/images/packetimpact/Dockerfile +++ b/images/packetimpact/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:bionic +FROM ubuntu:focal RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ # iptables to disable OS native packet processing. iptables \ @@ -11,6 +11,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ # tshark to log verbose packet sniffing. tshark \ # killall for cleanup. - psmisc -RUN hash -r -CMD /bin/bash + psmisc \ + # qemu-system-x86 to emulate fuchsia. + qemu-system-x86 \ + # sha1sum to generate entropy. + libdigest-sha-perl diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 2b789c4ec..cdcaa8c73 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -29,6 +29,7 @@ go_library( "file_amd64.go", "file_arm64.go", "fs.go", + "fuse.go", "futex.go", "inotify.go", "ioctl.go", @@ -40,6 +41,7 @@ go_library( "mm.go", "netdevice.go", "netfilter.go", + "netfilter_ipv6.go", "netlink.go", "netlink_route.go", "poll.go", @@ -72,6 +74,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/marshal", + "//pkg/marshal/primitive", + "//pkg/usermem", ], ) diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go index 86ee3f8b5..5fc099892 100644 --- a/pkg/abi/linux/aio.go +++ b/pkg/abi/linux/aio.go @@ -42,6 +42,8 @@ const ( // // The priority field is currently ignored in the implementation below. Also // note that the IOCB_FLAG_RESFD feature is not supported. +// +// +marshal type IOCallback struct { Data uint64 Key uint32 @@ -64,6 +66,7 @@ type IOCallback struct { // IOEvent describes an I/O result. // +// +marshal // +stateify savable type IOEvent struct { Data uint64 diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go index aa3d3ce70..9422fcf69 100644 --- a/pkg/abi/linux/bpf.go +++ b/pkg/abi/linux/bpf.go @@ -16,6 +16,7 @@ package linux // BPFInstruction is a raw BPF virtual machine instruction. // +// +marshal slice:BPFInstructionSlice // +stateify savable type BPFInstruction struct { // OpCode is the operation to execute. diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go index 965f74663..afd16cc27 100644 --- a/pkg/abi/linux/capability.go +++ b/pkg/abi/linux/capability.go @@ -177,12 +177,16 @@ const ( ) // CapUserHeader is equivalent to Linux's cap_user_header_t. +// +// +marshal type CapUserHeader struct { Version uint32 Pid int32 } // CapUserData is equivalent to Linux's cap_user_data_t. +// +// +marshal slice:CapUserDataSlice type CapUserData struct { Effective uint32 Permitted uint32 diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go index 192e2093b..7771650b3 100644 --- a/pkg/abi/linux/dev.go +++ b/pkg/abi/linux/dev.go @@ -54,9 +54,9 @@ const ( // Unix98 PTY masters. UNIX98_PTY_MASTER_MAJOR = 128 - // UNIX98_PTY_SLAVE_MAJOR is the initial major device number for - // Unix98 PTY slaves. - UNIX98_PTY_SLAVE_MAJOR = 136 + // UNIX98_PTY_REPLICA_MAJOR is the initial major device number for + // Unix98 PTY replicas. + UNIX98_PTY_REPLICA_MAJOR = 136 ) // Minor device numbers for TTYAUX_MAJOR. diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index 9242e80a5..cc3571fad 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -45,6 +45,8 @@ const ( ) // Flock is the lock structure for F_SETLK. +// +// +marshal type Flock struct { Type int16 Whence int16 @@ -63,6 +65,8 @@ const ( ) // FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +// +// +marshal type FOwnerEx struct { Type int32 PID int32 diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index 158d2db5b..0d921ed6f 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -29,6 +29,7 @@ const ( SYSFS_MAGIC = 0x62656572 TMPFS_MAGIC = 0x01021994 V9FS_MAGIC = 0x01021997 + FUSE_SUPER_MAGIC = 0x65735546 ) // Filesystem path limits, from uapi/linux/limits.h. @@ -44,17 +45,18 @@ type Statfs struct { // Type is one of the filesystem magic values, defined above. Type uint64 - // BlockSize is the data block size. + // BlockSize is the optimal transfer block size in bytes. BlockSize int64 - // Blocks is the number of data blocks in use. + // Blocks is the maximum number of data blocks the filesystem may store, in + // units of BlockSize. Blocks uint64 - // BlocksFree is the number of free blocks. + // BlocksFree is the number of free data blocks, in units of BlockSize. BlocksFree uint64 - // BlocksAvailable is the number of blocks free for use by - // unprivileged users. + // BlocksAvailable is the number of data blocks free for use by + // unprivileged users, in units of BlockSize. BlocksAvailable uint64 // Files is the number of used file nodes on the filesystem. diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go new file mode 100644 index 000000000..d91c97a64 --- /dev/null +++ b/pkg/abi/linux/fuse.go @@ -0,0 +1,873 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// +marshal +type FUSEOpcode uint32 + +// +marshal +type FUSEOpID uint64 + +// FUSE_ROOT_ID is the id of root inode. +const FUSE_ROOT_ID = 1 + +// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +const ( + FUSE_LOOKUP FUSEOpcode = 1 + FUSE_FORGET = 2 /* no reply */ + FUSE_GETATTR = 3 + FUSE_SETATTR = 4 + FUSE_READLINK = 5 + FUSE_SYMLINK = 6 + _ + FUSE_MKNOD = 8 + FUSE_MKDIR = 9 + FUSE_UNLINK = 10 + FUSE_RMDIR = 11 + FUSE_RENAME = 12 + FUSE_LINK = 13 + FUSE_OPEN = 14 + FUSE_READ = 15 + FUSE_WRITE = 16 + FUSE_STATFS = 17 + FUSE_RELEASE = 18 + _ + FUSE_FSYNC = 20 + FUSE_SETXATTR = 21 + FUSE_GETXATTR = 22 + FUSE_LISTXATTR = 23 + FUSE_REMOVEXATTR = 24 + FUSE_FLUSH = 25 + FUSE_INIT = 26 + FUSE_OPENDIR = 27 + FUSE_READDIR = 28 + FUSE_RELEASEDIR = 29 + FUSE_FSYNCDIR = 30 + FUSE_GETLK = 31 + FUSE_SETLK = 32 + FUSE_SETLKW = 33 + FUSE_ACCESS = 34 + FUSE_CREATE = 35 + FUSE_INTERRUPT = 36 + FUSE_BMAP = 37 + FUSE_DESTROY = 38 + FUSE_IOCTL = 39 + FUSE_POLL = 40 + FUSE_NOTIFY_REPLY = 41 + FUSE_BATCH_FORGET = 42 +) + +const ( + // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem. + // This is the minimum size Linux supports. See linux.fuse.h. + FUSE_MIN_READ_BUFFER uint32 = 8192 +) + +// FUSEHeaderIn is the header read by the daemon with each request. +// +// +marshal +type FUSEHeaderIn struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Opcode specifies the kind of operation of the request. + Opcode FUSEOpcode + + // Unique specifies the unique identifier for this request. + Unique FUSEOpID + + // NodeID is the ID of the filesystem object being operated on. + NodeID uint64 + + // UID is the UID of the requesting process. + UID uint32 + + // GID is the GID of the requesting process. + GID uint32 + + // PID is the PID of the requesting process. + PID uint32 + + _ uint32 +} + +// FUSEHeaderOut is the header written by the daemon when it processes +// a request and wants to send a reply (almost all operations require a +// reply; if they do not, this will be explicitly documented). +// +// +marshal +type FUSEHeaderOut struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Error specifies the error that occurred (0 if none). + Error int32 + + // Unique specifies the unique identifier of the corresponding request. + Unique FUSEOpID +} + +// FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h. +// Our taget version is 7.23 but we have few implemented in advance. +const ( + FUSE_ASYNC_READ = 1 << 0 + FUSE_POSIX_LOCKS = 1 << 1 + FUSE_FILE_OPS = 1 << 2 + FUSE_ATOMIC_O_TRUNC = 1 << 3 + FUSE_EXPORT_SUPPORT = 1 << 4 + FUSE_BIG_WRITES = 1 << 5 + FUSE_DONT_MASK = 1 << 6 + FUSE_SPLICE_WRITE = 1 << 7 + FUSE_SPLICE_MOVE = 1 << 8 + FUSE_SPLICE_READ = 1 << 9 + FUSE_FLOCK_LOCKS = 1 << 10 + FUSE_HAS_IOCTL_DIR = 1 << 11 + FUSE_AUTO_INVAL_DATA = 1 << 12 + FUSE_DO_READDIRPLUS = 1 << 13 + FUSE_READDIRPLUS_AUTO = 1 << 14 + FUSE_ASYNC_DIO = 1 << 15 + FUSE_WRITEBACK_CACHE = 1 << 16 + FUSE_NO_OPEN_SUPPORT = 1 << 17 + FUSE_MAX_PAGES = 1 << 22 // From FUSE 7.28 +) + +// currently supported FUSE protocol version numbers. +const ( + FUSE_KERNEL_VERSION = 7 + FUSE_KERNEL_MINOR_VERSION = 31 +) + +// Constants relevant to FUSE operations. +const ( + FUSE_NAME_MAX = 1024 + FUSE_PAGE_SIZE = 4096 + FUSE_DIRENT_ALIGN = 8 +) + +// FUSEInitIn is the request sent by the kernel to the daemon, +// to negotiate the version and flags. +// +// +marshal +type FUSEInitIn struct { + // Major version supported by kernel. + Major uint32 + + // Minor version supported by the kernel. + Minor uint32 + + // MaxReadahead is the maximum number of bytes to read-ahead + // decided by the kernel. + MaxReadahead uint32 + + // Flags of this init request. + Flags uint32 +} + +// FUSEInitOut is the reply sent by the daemon to the kernel +// for FUSEInitIn. We target FUSE 7.23; this struct supports 7.28. +// +// +marshal +type FUSEInitOut struct { + // Major version supported by daemon. + Major uint32 + + // Minor version supported by daemon. + Minor uint32 + + // MaxReadahead is the maximum number of bytes to read-ahead. + // Decided by the daemon, after receiving the value from kernel. + MaxReadahead uint32 + + // Flags of this init reply. + Flags uint32 + + // MaxBackground is the maximum number of pending background requests + // that the daemon wants. + MaxBackground uint16 + + // CongestionThreshold is the daemon-decided threshold for + // the number of the pending background requests. + CongestionThreshold uint16 + + // MaxWrite is the daemon's maximum size of a write buffer. + // Kernel adjusts it to the minimum (fuse/init.go:fuseMinMaxWrite). + // if the value from daemon is too small. + MaxWrite uint32 + + // TimeGran is the daemon's time granularity for mtime and ctime metadata. + // The unit is nanosecond. + // Value should be power of 10. + // 1 indicates full nanosecond granularity support. + TimeGran uint32 + + // MaxPages is the daemon's maximum number of pages for one write operation. + // Kernel adjusts it to the maximum (fuse/init.go:FUSE_MAX_MAX_PAGES). + // if the value from daemon is too large. + MaxPages uint16 + + _ uint16 + + _ [8]uint32 +} + +// FUSE_GETATTR_FH is currently the only flag of FUSEGetAttrIn.GetAttrFlags. +// If it is set, the file handle (FUSEGetAttrIn.Fh) is used to indicate the +// object instead of the node id attribute in the request header. +const FUSE_GETATTR_FH = (1 << 0) + +// FUSEGetAttrIn is the request sent by the kernel to the daemon, +// to get the attribute of a inode. +// +// +marshal +type FUSEGetAttrIn struct { + // GetAttrFlags specifies whether getattr request is sent with a nodeid or + // with a file handle. + GetAttrFlags uint32 + + _ uint32 + + // Fh is the file handler when GetAttrFlags has FUSE_GETATTR_FH bit. If + // used, the operation is analogous to fstat(2). + Fh uint64 +} + +// FUSEAttr is the struct used in the reponse FUSEGetAttrOut. +// +// +marshal +type FUSEAttr struct { + // Ino is the inode number of this file. + Ino uint64 + + // Size is the size of this file. + Size uint64 + + // Blocks is the number of the 512B blocks allocated by this file. + Blocks uint64 + + // Atime is the time of last access. + Atime uint64 + + // Mtime is the time of last modification. + Mtime uint64 + + // Ctime is the time of last status change. + Ctime uint64 + + // AtimeNsec is the nano second part of Atime. + AtimeNsec uint32 + + // MtimeNsec is the nano second part of Mtime. + MtimeNsec uint32 + + // CtimeNsec is the nano second part of Ctime. + CtimeNsec uint32 + + // Mode contains the file type and mode. + Mode uint32 + + // Nlink is the number of the hard links. + Nlink uint32 + + // UID is user ID of the owner. + UID uint32 + + // GID is group ID of the owner. + GID uint32 + + // Rdev is the device ID if this is a special file. + Rdev uint32 + + // BlkSize is the block size for filesystem I/O. + BlkSize uint32 + + _ uint32 +} + +// FUSEGetAttrOut is the reply sent by the daemon to the kernel +// for FUSEGetAttrIn. +// +// +marshal +type FUSEGetAttrOut struct { + // AttrValid and AttrValidNsec describe the attribute cache duration + AttrValid uint64 + + // AttrValidNsec is the nanosecond part of the attribute cache duration + AttrValidNsec uint32 + + _ uint32 + + // Attr contains the metadata returned from the FUSE server + Attr FUSEAttr +} + +// FUSEEntryOut is the reply sent by the daemon to the kernel +// for FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and +// FUSE_LOOKUP. +// +// +marshal +type FUSEEntryOut struct { + // NodeID is the ID for current inode. + NodeID uint64 + + // Generation is the generation number of inode. + // Used to identify an inode that have different ID at different time. + Generation uint64 + + // EntryValid indicates timeout for an entry. + EntryValid uint64 + + // AttrValid indicates timeout for an entry's attributes. + AttrValid uint64 + + // EntryValidNsec indicates timeout for an entry in nanosecond. + EntryValidNSec uint32 + + // AttrValidNsec indicates timeout for an entry's attributes in nanosecond. + AttrValidNSec uint32 + + // Attr contains the attributes of an entry. + Attr FUSEAttr +} + +// FUSELookupIn is the request sent by the kernel to the daemon +// to look up a file name. +// +// Dynamically-sized objects cannot be marshalled. +type FUSELookupIn struct { + marshal.StubMarshallable + + // Name is a file name to be looked up. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSELookupIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSELookupIn. +// 1 extra byte for null-terminated string. +func (r *FUSELookupIn) SizeBytes() int { + return len(r.Name) + 1 +} + +// MAX_NON_LFS indicates the maximum offset without large file support. +const MAX_NON_LFS = ((1 << 31) - 1) + +// flags returned by OPEN request. +const ( + // FOPEN_DIRECT_IO indicates bypassing page cache for this opened file. + FOPEN_DIRECT_IO = 1 << 0 + // FOPEN_KEEP_CACHE avoids invalidate of data cache on open. + FOPEN_KEEP_CACHE = 1 << 1 + // FOPEN_NONSEEKABLE indicates the file cannot be seeked. + FOPEN_NONSEEKABLE = 1 << 2 +) + +// FUSEOpenIn is the request sent by the kernel to the daemon, +// to negotiate flags and get file handle. +// +// +marshal +type FUSEOpenIn struct { + // Flags of this open request. + Flags uint32 + + _ uint32 +} + +// FUSEOpenOut is the reply sent by the daemon to the kernel +// for FUSEOpenIn. +// +// +marshal +type FUSEOpenOut struct { + // Fh is the file handler for opened file. + Fh uint64 + + // OpenFlag for the opened file. + OpenFlag uint32 + + _ uint32 +} + +// FUSE_READ flags, consistent with the ones in include/uapi/linux/fuse.h. +const ( + FUSE_READ_LOCKOWNER = 1 << 1 +) + +// FUSEReadIn is the request sent by the kernel to the daemon +// for FUSE_READ. +// +// +marshal +type FUSEReadIn struct { + // Fh is the file handle in userspace. + Fh uint64 + + // Offset is the read offset. + Offset uint64 + + // Size is the number of bytes to read. + Size uint32 + + // ReadFlags for this FUSE_READ request. + // Currently only contains FUSE_READ_LOCKOWNER. + ReadFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 + + // Flags for the underlying file. + Flags uint32 + + _ uint32 +} + +// FUSEWriteIn is the first part of the payload of the +// request sent by the kernel to the daemon +// for FUSE_WRITE (struct for FUSE version >= 7.9). +// +// The second part of the payload is the +// binary bytes of the data to be written. +// +// +marshal +type FUSEWriteIn struct { + // Fh is the file handle in userspace. + Fh uint64 + + // Offset is the write offset. + Offset uint64 + + // Size is the number of bytes to write. + Size uint32 + + // ReadFlags for this FUSE_WRITE request. + WriteFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 + + // Flags for the underlying file. + Flags uint32 + + _ uint32 +} + +// FUSEWriteOut is the payload of the reply sent by the daemon to the kernel +// for a FUSE_WRITE request. +// +// +marshal +type FUSEWriteOut struct { + // Size is the number of bytes written. + Size uint32 + + _ uint32 +} + +// FUSEReleaseIn is the request sent by the kernel to the daemon +// when there is no more reference to a file. +// +// +marshal +type FUSEReleaseIn struct { + // Fh is the file handler for the file to be released. + Fh uint64 + + // Flags of the file. + Flags uint32 + + // ReleaseFlags of this release request. + ReleaseFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 +} + +// FUSECreateMeta contains all the static fields of FUSECreateIn, +// which is used for FUSE_CREATE. +// +// +marshal +type FUSECreateMeta struct { + // Flags of the creating file. + Flags uint32 + + // Mode is the mode of the creating file. + Mode uint32 + + // Umask is the current file mode creation mask. + Umask uint32 + _ uint32 +} + +// FUSECreateIn contains all the arguments sent by the kernel to the daemon, to +// atomically create and open a new regular file. +// +// Dynamically-sized objects cannot be marshalled. +type FUSECreateIn struct { + marshal.StubMarshallable + + // CreateMeta contains mode, rdev and umash field for FUSE_MKNODS. + CreateMeta FUSECreateMeta + + // Name is the name of the node to create. + Name string +} + +// MarshalBytes serializes r.CreateMeta and r.Name to the dst buffer. +func (r *FUSECreateIn) MarshalBytes(buf []byte) { + r.CreateMeta.MarshalBytes(buf[:r.CreateMeta.SizeBytes()]) + copy(buf[r.CreateMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSECreateIn. +// 1 extra byte for null-terminated string. +func (r *FUSECreateIn) SizeBytes() int { + return r.CreateMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSEMknodMeta contains all the static fields of FUSEMknodIn, +// which is used for FUSE_MKNOD. +// +// +marshal +type FUSEMknodMeta struct { + // Mode of the inode to create. + Mode uint32 + + // Rdev encodes device major and minor information. + Rdev uint32 + + // Umask is the current file mode creation mask. + Umask uint32 + + _ uint32 +} + +// FUSEMknodIn contains all the arguments sent by the kernel +// to the daemon, to create a new file node. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEMknodIn struct { + marshal.StubMarshallable + + // MknodMeta contains mode, rdev and umash field for FUSE_MKNODS. + MknodMeta FUSEMknodMeta + + // Name is the name of the node to create. + Name string +} + +// MarshalBytes serializes r.MknodMeta and r.Name to the dst buffer. +func (r *FUSEMknodIn) MarshalBytes(buf []byte) { + r.MknodMeta.MarshalBytes(buf[:r.MknodMeta.SizeBytes()]) + copy(buf[r.MknodMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEMknodIn. +// 1 extra byte for null-terminated string. +func (r *FUSEMknodIn) SizeBytes() int { + return r.MknodMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSESymLinkIn is the request sent by the kernel to the daemon, +// to create a symbolic link. +// +// Dynamically-sized objects cannot be marshalled. +type FUSESymLinkIn struct { + marshal.StubMarshallable + + // Name of symlink to create. + Name string + + // Target of the symlink. + Target string +} + +// MarshalBytes serializes r.Name and r.Target to the dst buffer. +// Left null-termination at end of r.Name and r.Target. +func (r *FUSESymLinkIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) + copy(buf[len(r.Name)+1:], r.Target) +} + +// SizeBytes is the size of the memory representation of FUSESymLinkIn. +// 2 extra bytes for null-terminated string. +func (r *FUSESymLinkIn) SizeBytes() int { + return len(r.Name) + len(r.Target) + 2 +} + +// FUSEEmptyIn is used by operations without request body. +type FUSEEmptyIn struct{ marshal.StubMarshallable } + +// MarshalBytes do nothing for marshal. +func (r *FUSEEmptyIn) MarshalBytes(buf []byte) {} + +// SizeBytes is 0 for empty request. +func (r *FUSEEmptyIn) SizeBytes() int { + return 0 +} + +// FUSEMkdirMeta contains all the static fields of FUSEMkdirIn, +// which is used for FUSE_MKDIR. +// +// +marshal +type FUSEMkdirMeta struct { + // Mode of the directory of create. + Mode uint32 + + // Umask is the user file creation mask. + Umask uint32 +} + +// FUSEMkdirIn contains all the arguments sent by the kernel +// to the daemon, to create a new directory. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEMkdirIn struct { + marshal.StubMarshallable + + // MkdirMeta contains Mode and Umask of the directory to create. + MkdirMeta FUSEMkdirMeta + + // Name of the directory to create. + Name string +} + +// MarshalBytes serializes r.MkdirMeta and r.Name to the dst buffer. +func (r *FUSEMkdirIn) MarshalBytes(buf []byte) { + r.MkdirMeta.MarshalBytes(buf[:r.MkdirMeta.SizeBytes()]) + copy(buf[r.MkdirMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEMkdirIn. +// 1 extra byte for null-terminated Name string. +func (r *FUSEMkdirIn) SizeBytes() int { + return r.MkdirMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSERmDirIn is the request sent by the kernel to the daemon +// when trying to remove a directory. +// +// Dynamically-sized objects cannot be marshalled. +type FUSERmDirIn struct { + marshal.StubMarshallable + + // Name is a directory name to be removed. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSERmDirIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSERmDirIn. +func (r *FUSERmDirIn) SizeBytes() int { + return len(r.Name) + 1 +} + +// FUSEDirents is a list of Dirents received from the FUSE daemon server. +// It is used for FUSE_READDIR. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEDirents struct { + marshal.StubMarshallable + + Dirents []*FUSEDirent +} + +// FUSEDirent is a Dirent received from the FUSE daemon server. +// It is used for FUSE_READDIR. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEDirent struct { + marshal.StubMarshallable + + // Meta contains all the static fields of FUSEDirent. + Meta FUSEDirentMeta + + // Name is the filename of the dirent. + Name string +} + +// FUSEDirentMeta contains all the static fields of FUSEDirent. +// It is used for FUSE_READDIR. +// +// +marshal +type FUSEDirentMeta struct { + // Inode of the dirent. + Ino uint64 + + // Offset of the dirent. + Off uint64 + + // NameLen is the length of the dirent name. + NameLen uint32 + + // Type of the dirent. + Type uint32 +} + +// SizeBytes is the size of the memory representation of FUSEDirents. +func (r *FUSEDirents) SizeBytes() int { + var sizeBytes int + for _, dirent := range r.Dirents { + sizeBytes += dirent.SizeBytes() + } + + return sizeBytes +} + +// UnmarshalBytes deserializes FUSEDirents from the src buffer. +func (r *FUSEDirents) UnmarshalBytes(src []byte) { + for { + if len(src) <= (*FUSEDirentMeta)(nil).SizeBytes() { + break + } + + // Its unclear how many dirents there are in src. Each dirent is dynamically + // sized and so we can't make assumptions about how many dirents we can allocate. + if r.Dirents == nil { + r.Dirents = make([]*FUSEDirent, 0) + } + + // We have to allocate a struct for each dirent - there must be a better way + // to do this. Linux allocates 1 page to store all the dirents and then + // simply reads them from the page. + var dirent FUSEDirent + dirent.UnmarshalBytes(src) + r.Dirents = append(r.Dirents, &dirent) + + src = src[dirent.SizeBytes():] + } +} + +// SizeBytes is the size of the memory representation of FUSEDirent. +func (r *FUSEDirent) SizeBytes() int { + dataSize := r.Meta.SizeBytes() + len(r.Name) + + // Each Dirent must be padded such that its size is a multiple + // of FUSE_DIRENT_ALIGN. Similar to the fuse dirent alignment + // in linux/fuse.h. + return (dataSize + (FUSE_DIRENT_ALIGN - 1)) & ^(FUSE_DIRENT_ALIGN - 1) +} + +// UnmarshalBytes deserializes FUSEDirent from the src buffer. +func (r *FUSEDirent) UnmarshalBytes(src []byte) { + r.Meta.UnmarshalBytes(src) + src = src[r.Meta.SizeBytes():] + + if r.Meta.NameLen > FUSE_NAME_MAX { + // The name is too long and therefore invalid. We don't + // need to unmarshal the name since it'll be thrown away. + return + } + + buf := make([]byte, r.Meta.NameLen) + name := primitive.ByteSlice(buf) + name.UnmarshalBytes(src[:r.Meta.NameLen]) + r.Name = string(name) +} + +// FATTR_* consts are the attribute flags defined in include/uapi/linux/fuse.h. +// These should be or-ed together for setattr to know what has been changed. +const ( + FATTR_MODE = (1 << 0) + FATTR_UID = (1 << 1) + FATTR_GID = (1 << 2) + FATTR_SIZE = (1 << 3) + FATTR_ATIME = (1 << 4) + FATTR_MTIME = (1 << 5) + FATTR_FH = (1 << 6) + FATTR_ATIME_NOW = (1 << 7) + FATTR_MTIME_NOW = (1 << 8) + FATTR_LOCKOWNER = (1 << 9) + FATTR_CTIME = (1 << 10) +) + +// FUSESetAttrIn is the request sent by the kernel to the daemon, +// to set the attribute(s) of a file. +// +// +marshal +type FUSESetAttrIn struct { + // Valid indicates which attributes are modified by this request. + Valid uint32 + + _ uint32 + + // Fh is used to identify the file if FATTR_FH is set in Valid. + Fh uint64 + + // Size is the size that the request wants to change to. + Size uint64 + + // LockOwner is the owner of the lock that the request wants to change to. + LockOwner uint64 + + // Atime is the access time that the request wants to change to. + Atime uint64 + + // Mtime is the modification time that the request wants to change to. + Mtime uint64 + + // Ctime is the status change time that the request wants to change to. + Ctime uint64 + + // AtimeNsec is the nano second part of Atime. + AtimeNsec uint32 + + // MtimeNsec is the nano second part of Mtime. + MtimeNsec uint32 + + // CtimeNsec is the nano second part of Ctime. + CtimeNsec uint32 + + // Mode is the file mode that the request wants to change to. + Mode uint32 + + _ uint32 + + // UID is the user ID of the owner that the request wants to change to. + UID uint32 + + // GID is the group ID of the owner that the request wants to change to. + GID uint32 + + _ uint32 +} + +// FUSEUnlinkIn is the request sent by the kernel to the daemon +// when trying to unlink a node. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEUnlinkIn struct { + marshal.StubMarshallable + + // Name of the node to unlink. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer, which should +// have size len(r.Name) + 1 and last byte set to 0. +func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEUnlinkIn. +// 1 extra byte for null-terminated Name string. +func (r *FUSEUnlinkIn) SizeBytes() int { + return len(r.Name) + 1 +} diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go index 08bfde3b5..8138088a6 100644 --- a/pkg/abi/linux/futex.go +++ b/pkg/abi/linux/futex.go @@ -60,3 +60,21 @@ const ( FUTEX_WAITERS = 0x80000000 FUTEX_OWNER_DIED = 0x40000000 ) + +// FUTEX_BITSET_MATCH_ANY has all bits set. +const FUTEX_BITSET_MATCH_ANY = 0xffffffff + +// ROBUST_LIST_LIMIT protects against a deliberately circular list. +const ROBUST_LIST_LIMIT = 2048 + +// RobustListHead corresponds to Linux's struct robust_list_head. +// +// +marshal +type RobustListHead struct { + List uint64 + FutexOffset uint64 + ListOpPending uint64 +} + +// SizeOfRobustListHead is the size of a RobustListHead struct. +var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes() diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2062e6a4b..7df02dd6d 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -67,10 +67,29 @@ const ( // ioctl(2) requests provided by uapi/linux/sockios.h const ( - SIOCGIFMEM = 0x891f - SIOCGIFPFLAGS = 0x8935 - SIOCGMIIPHY = 0x8947 - SIOCGMIIREG = 0x8948 + SIOCGIFNAME = 0x8910 + SIOCGIFCONF = 0x8912 + SIOCGIFFLAGS = 0x8913 + SIOCGIFADDR = 0x8915 + SIOCGIFDSTADDR = 0x8917 + SIOCGIFBRDADDR = 0x8919 + SIOCGIFNETMASK = 0x891b + SIOCGIFMETRIC = 0x891d + SIOCGIFMTU = 0x8921 + SIOCGIFMEM = 0x891f + SIOCGIFHWADDR = 0x8927 + SIOCGIFINDEX = 0x8933 + SIOCGIFPFLAGS = 0x8935 + SIOCGIFTXQLEN = 0x8942 + SIOCETHTOOL = 0x8946 + SIOCGMIIPHY = 0x8947 + SIOCGMIIREG = 0x8948 + SIOCGIFMAP = 0x8970 +) + +// ioctl(2) requests provided by uapi/asm-generic/sockios.h +const ( + SIOCGSTAMP = 0x8906 ) // ioctl(2) directions. Used to calculate requests number. @@ -94,7 +113,57 @@ const ( _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS ) +// Constants from uapi/linux/fs.h. +const ( + FS_IOC_GETFLAGS = 2148034049 + FS_VERITY_FL = 1048576 +) + +// Constants from uapi/linux/fsverity.h. +const ( + FS_IOC_ENABLE_VERITY = 1082156677 + FS_IOC_MEASURE_VERITY = 3221513862 +) + +// DigestMetadata is a helper struct for VerityDigest. +// +// +marshal +type DigestMetadata struct { + DigestAlgorithm uint16 + DigestSize uint16 +} + +// SizeOfDigestMetadata is the size of struct DigestMetadata. +const SizeOfDigestMetadata = 4 + +// VerityDigest is struct from uapi/linux/fsverity.h. +type VerityDigest struct { + Metadata DigestMetadata + Digest []byte +} + // IOC outputs the result of _IOC macro in asm-generic/ioctl.h. func IOC(dir, typ, nr, size uint32) uint32 { return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT } + +// Kcov ioctls from kernel/kcov.h. +var ( + KCOV_INIT_TRACE = IOC(_IOC_READ, 'c', 1, 8) + KCOV_ENABLE = IOC(_IOC_NONE, 'c', 100, 0) + KCOV_DISABLE = IOC(_IOC_NONE, 'c', 101, 0) +) + +// Kcov trace types from kernel/kcov.h. +const ( + KCOV_TRACE_PC = 0 + KCOV_TRACE_CMP = 1 +) + +// Kcov state constants from kernel/kcov.h. +const ( + KCOV_MODE_DISABLED = 0 + KCOV_MODE_INIT = 1 + KCOV_MODE_TRACE_PC = 2 + KCOV_MODE_TRACE_CMP = 3 +) diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go index 22acd2d43..c6e65df62 100644 --- a/pkg/abi/linux/ipc.go +++ b/pkg/abi/linux/ipc.go @@ -37,6 +37,8 @@ const IPC_PRIVATE = 0 // features like 32-bit UIDs. // IPCPerm is equivalent to struct ipc64_perm. +// +// +marshal type IPCPerm struct { Key uint32 UID uint32 diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go index 281acdbde..3b4abece1 100644 --- a/pkg/abi/linux/linux.go +++ b/pkg/abi/linux/linux.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux contains the constants and types needed to interface with a Linux kernel. +// Package linux contains the constants and types needed to interface with a +// Linux kernel. package linux // NumSoftIRQ is the number of software IRQs, exposed via /proc/stat. @@ -21,6 +22,8 @@ package linux const NumSoftIRQ = 10 // Sysinfo is the structure provided by sysinfo on linux versions > 2.3.48. +// +// +marshal type Sysinfo struct { Uptime int64 Loads [3]uint64 @@ -34,6 +37,6 @@ type Sysinfo struct { _ [6]byte // Pad Procs to 64bits. TotalHigh uint64 FreeHigh uint64 - Unit uint32 - /* The _f field in the glibc version of Sysinfo has size 0 on AMD64 */ + Unit uint32 `marshal:"unaligned"` // Struct ends mid-64-bit-word. + // The _f field in the glibc version of Sysinfo has size 0 on AMD64. } diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 7866352b4..0faf015c7 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -22,6 +22,8 @@ const ( ) // IFReq is an interface request. +// +// +marshal type IFReq struct { // IFName is an encoded name, normally null-terminated. This should be // accessed via the Name and SetName functions. @@ -79,6 +81,8 @@ type IFMap struct { // IFConf is used to return a list of interfaces and their addresses. See // netdevice(7) and struct ifconf for more detail on its use. +// +// +marshal type IFConf struct { Len int32 _ [4]byte // Pad to sizeof(struct ifconf). diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..b521144d9 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/usermem" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -51,7 +59,7 @@ var VerdictStrings = map[int32]string{ NF_RETURN: "RETURN", } -// Socket options. These correspond to values in +// Socket options for SOL_SOCKET. These correspond to values in // include/uapi/linux/netfilter_ipv4/ip_tables.h. const ( IPT_BASE_CTL = 64 @@ -66,6 +74,12 @@ const ( IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET ) +// Socket option for SOL_IP. This corresponds to the value in +// include/uapi/linux/netfilter_ipv4.h. +const ( + SO_ORIGINAL_DST = 80 +) + // Name lengths. These correspond to values in // include/uapi/linux/netfilter/x_tables.h. const ( @@ -76,6 +90,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +128,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +225,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -227,6 +265,18 @@ type KernelXTEntryMatch struct { Data []byte } +// XTGetRevision corresponds to xt_get_revision in +// include/uapi/linux/netfilter/x_tables.h +// +// +marshal +type XTGetRevision struct { + Name ExtensionName + Revision uint8 +} + +// SizeOfXTGetRevision is the size of an XTGetRevision. +const SizeOfXTGetRevision = 30 + // XTEntryTarget holds a target for a rule. For example, it can specify that // packets matching the rule should DROP, ACCEPT, or use an extension target. // iptables-extension(8) has a list of possible targets. @@ -247,6 +297,13 @@ type XTEntryTarget struct { // SizeOfXTEntryTarget is the size of an XTEntryTarget. const SizeOfXTEntryTarget = 32 +// KernelXTEntryTarget is identical to XTEntryTarget, but contains a +// variable-length Data field. +type KernelXTEntryTarget struct { + XTEntryTarget + Data []byte +} + // XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, // RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. @@ -321,6 +378,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +395,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +411,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := range ke.Entrytable { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := range ke.Entrytable { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return cc.CopyOutBytes(addr, ke.marshalAll(cc)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,16 +525,12 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 // ExtensionName holds the name of a netfilter extension. +// +// +marshal type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte // String implements fmt.Stringer. @@ -392,6 +539,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go new file mode 100644 index 000000000..6d31eb5e3 --- /dev/null +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -0,0 +1,336 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "io" + + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/usermem" +) + +// This file contains structures required to support IPv6 netfilter and +// ip6tables. Some constants and structs are equal to their IPv4 analogues, and +// are only distinguished by context (e.g. whether used on an IPv4 of IPv6 +// socket). + +// Socket options for SOL_SOCLET. These correspond to values in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + IP6T_BASE_CTL = 64 + IP6T_SO_SET_REPLACE = IPT_BASE_CTL + IP6T_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1 + IP6T_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS + + IP6T_SO_GET_INFO = IPT_BASE_CTL + IP6T_SO_GET_ENTRIES = IPT_BASE_CTL + 1 + IP6T_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 4 + IP6T_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 5 + IP6T_SO_GET_MAX = IP6T_SO_GET_REVISION_TARGET +) + +// IP6T_ORIGINAL_DST is the ip6tables SOL_IPV6 socket option. Corresponds to +// the value in include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// TODO(gvisor.dev/issue/3549): Support IPv6 original destination. +const IP6T_ORIGINAL_DST = 80 + +// IP6TReplace is the argument for the IP6T_SO_SET_REPLACE sockopt. It +// corresponds to struct ip6t_replace in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// +// +marshal +type IP6TReplace struct { + Name TableName + ValidHooks uint32 + NumEntries uint32 + Size uint32 + HookEntry [NF_INET_NUMHOOKS]uint32 + Underflow [NF_INET_NUMHOOKS]uint32 + NumCounters uint32 + Counters uint64 // This is really a *XTCounters. + // Entries is omitted here because it would cause IP6TReplace to be an + // extra byte longer (see http://www.catb.org/esr/structure-packing/). + // Entries [0]IP6TEntry +} + +// SizeOfIP6TReplace is the size of an IP6TReplace. +const SizeOfIP6TReplace = 96 + +// KernelIP6TGetEntries is identical to IP6TGetEntries, but includes the +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. +type KernelIP6TGetEntries struct { + IPTGetEntries + Entrytable []KernelIP6TEntry +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIP6TGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := range ke.Entrytable { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := range ke.Entrytable { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIP6TGetEntries) Packed() bool { + // KernelIP6TGetEntries isn't packed because the ke.Entrytable contains + // an indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIP6TGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIP6TGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results + // in a partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, + // fall back to MarshalBytes. + return cc.CopyOutBytes(addr, ke.marshalAll(cc)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) +} + +func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIP6TGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil) + +// IP6TEntry is an iptables rule. It corresponds to struct ip6t_entry in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// +// +marshal +type IP6TEntry struct { + // IPv6 is used to filter packets based on the IPv6 header. + IPv6 IP6TIP + + // NFCache relates to kernel-internal caching and isn't used by + // userspace. + NFCache uint32 + + // TargetOffset is the byte offset from the beginning of this IPTEntry + // to the start of the entry's target. + TargetOffset uint16 + + // NextOffset is the byte offset from the beginning of this IPTEntry to + // the start of the next entry. It is thus also the size of the entry. + NextOffset uint16 + + // Comeback is a return pointer. It is not used by userspace. + Comeback uint32 + + _ [4]byte + + // Counters holds the packet and byte counts for this rule. + Counters XTCounters + + // Elems holds the data for all this rule's matches followed by the + // target. It is variable length -- users have to iterate over any + // matches and use TargetOffset and NextOffset to make sense of the + // data. + // + // Elems is omitted here because it would cause IPTEntry to be an extra + // byte larger (see http://www.catb.org/esr/structure-packing/). + // + // Elems [0]byte +} + +// SizeOfIP6TEntry is the size of an IP6TEntry. +const SizeOfIP6TEntry = 168 + +// KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field. +// KernelIP6TEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. +type KernelIP6TEntry struct { + Entry IP6TEntry + + // Elems holds the data for all this rule's matches followed by the + // target. It is variable length -- users have to iterate over any + // matches and use TargetOffset and NextOffset to make sense of the + // data. + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIP6TEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) +} + +// IP6TIP contains information for matching a packet's IP header. +// It corresponds to struct ip6t_ip6 in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +// +// +marshal +type IP6TIP struct { + // Src is the source IP address. + Src Inet6Addr + + // Dst is the destination IP address. + Dst Inet6Addr + + // SrcMask is the source IP mask. + SrcMask Inet6Addr + + // DstMask is the destination IP mask. + DstMask Inet6Addr + + // InputInterface is the input network interface. + InputInterface [IFNAMSIZ]byte + + // OutputInterface is the output network interface. + OutputInterface [IFNAMSIZ]byte + + // InputInterfaceMask is the input interface mask. + InputInterfaceMask [IFNAMSIZ]byte + + // OuputInterfaceMask is the output interface mask. + OutputInterfaceMask [IFNAMSIZ]byte + + // Protocol is the transport protocol. + Protocol uint16 + + // TOS matches TOS flags when Flags indicates filtering by TOS. + TOS uint8 + + // Flags define matching behavior for the IP header. + Flags uint8 + + // InverseFlags invert the meaning of fields in struct IPTIP. See the + // IP6T_INV_* flags. + InverseFlags uint8 + + // Linux defines in6_addr (Inet6Addr for us) as the union of a + // 16-element byte array and a 4-element 32-bit integer array, so the + // whole struct is 4-byte aligned. + _ [3]byte +} + +const SizeOfIP6TIP = 136 + +// Flags in IP6TIP.Flags. Corresponding constants are in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + // Whether to check the Protocol field. + IP6T_F_PROTO = 0x01 + // Whether to match the TOS field. + IP6T_F_TOS = 0x02 + // Indicates that the jump target is an aboslute GOTO, not an offset. + IP6T_F_GOTO = 0x04 + // Enables all flags. + IP6T_F_MASK = 0x07 +) + +// Flags in IP6TIP.InverseFlags. Corresponding constants are in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + // Invert the meaning of InputInterface. + IP6T_INV_VIA_IN = 0x01 + // Invert the meaning of OutputInterface. + IP6T_INV_VIA_OUT = 0x02 + // Invert the meaning of TOS. + IP6T_INV_TOS = 0x04 + // Invert the meaning of Src. + IP6T_INV_SRCIP = 0x08 + // Invert the meaning of Dst. + IP6T_INV_DSTIP = 0x10 + // Invert the meaning of the IPT_F_FRAG flag. + IP6T_INV_FRAG = 0x20 + // Enable all flags. + IP6T_INV_MASK = 0x7F +) + +// NFNATRange corresponds to struct nf_nat_range in +// include/uapi/linux/netfilter/nf_nat.h. +type NFNATRange struct { + Flags uint32 + MinAddr Inet6Addr + MaxAddr Inet6Addr + MinProto uint16 // Network byte order. + MaxProto uint16 // Network byte order. +} + +// SizeOfNFNATRange is the size of NFNATRange. +const SizeOfNFNATRange = 40 diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index 565dd550e..bf73271c6 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -36,6 +36,9 @@ func TestSizes(t *testing.T) { {XTEntryTarget{}, SizeOfXTEntryTarget}, {XTErrorTarget{}, SizeOfXTErrorTarget}, {XTStandardTarget{}, SizeOfXTStandardTarget}, + {IP6TReplace{}, SizeOfIP6TReplace}, + {IP6TEntry{}, SizeOfIP6TEntry}, + {IP6TIP{}, SizeOfIP6TIP}, } for _, tc := range testCases { diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index 0ba086c76..b41f94a69 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -40,6 +40,8 @@ const ( ) // SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h. +// +// +marshal type SockAddrNetlink struct { Family uint16 _ uint16 diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 40bec566c..ceda0a8d3 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -187,6 +187,8 @@ const ( // Device types, from uapi/linux/if_arp.h. const ( + ARPHRD_NONE = 65534 + ARPHRD_ETHER = 1 ARPHRD_LOOPBACK = 772 ) diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go index c04d26e4c..3443a5768 100644 --- a/pkg/abi/linux/poll.go +++ b/pkg/abi/linux/poll.go @@ -15,6 +15,8 @@ package linux // PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h. +// +// +marshal slice:PollFDSlice type PollFD struct { FD int32 Events int16 diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go index d8302dc85..e29d0ac7e 100644 --- a/pkg/abi/linux/rusage.go +++ b/pkg/abi/linux/rusage.go @@ -26,6 +26,8 @@ const ( ) // Rusage represents the Linux struct rusage. +// +// +marshal type Rusage struct { UTime Timeval STime Timeval diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go index d0607e256..5be3f10f9 100644 --- a/pkg/abi/linux/seccomp.go +++ b/pkg/abi/linux/seccomp.go @@ -34,11 +34,11 @@ type BPFAction uint32 const ( SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000 - SECCOMP_RET_KILL_THREAD = 0x00000000 - SECCOMP_RET_TRAP = 0x00030000 - SECCOMP_RET_ERRNO = 0x00050000 - SECCOMP_RET_TRACE = 0x7ff00000 - SECCOMP_RET_ALLOW = 0x7fff0000 + SECCOMP_RET_KILL_THREAD BPFAction = 0x00000000 + SECCOMP_RET_TRAP BPFAction = 0x00030000 + SECCOMP_RET_ERRNO BPFAction = 0x00050000 + SECCOMP_RET_TRACE BPFAction = 0x7ff00000 + SECCOMP_RET_ALLOW BPFAction = 0x7fff0000 ) func (a BPFAction) String() string { @@ -64,9 +64,41 @@ func (a BPFAction) Data() uint16 { return uint16(a & SECCOMP_RET_DATA) } +// WithReturnCode sets the lower 16 bits of the SECCOMP_RET_ERRNO or +// SECCOMP_RET_TRACE actions to the provided return code, overwriting the previous +// action, and returns a new BPFAction. If not SECCOMP_RET_ERRNO or +// SECCOMP_RET_TRACE then this panics. +func (a BPFAction) WithReturnCode(code uint16) BPFAction { + // mask out the previous return value + baseAction := a & SECCOMP_RET_ACTION_FULL + if baseAction == SECCOMP_RET_ERRNO || baseAction == SECCOMP_RET_TRACE { + return BPFAction(uint32(baseAction) | uint32(code)) + } + panic("WithReturnCode only valid for SECCOMP_RET_ERRNO and SECCOMP_RET_TRACE") +} + // SockFprog is sock_fprog taken from <linux/filter.h>. type SockFprog struct { Len uint16 pad [6]byte Filter *BPFInstruction } + +// SeccompData is equivalent to struct seccomp_data, which contains the data +// passed to seccomp-bpf filters. +// +// +marshal +type SeccompData struct { + // Nr is the system call number. + Nr int32 + + // Arch is an AUDIT_ARCH_* value indicating the system call convention. + Arch uint32 + + // InstructionPointer is the value of the instruction pointer at the time + // of the system call. + InstructionPointer uint64 + + // Args contains the first 6 system call arguments. + Args [6]uint64 +} diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index de422c519..487a626cc 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -35,6 +35,8 @@ const ( const SEM_UNDO = 0x1000 // SemidDS is equivalent to struct semid64_ds. +// +// +marshal type SemidDS struct { SemPerm IPCPerm SemOTime TimeT @@ -45,6 +47,8 @@ type SemidDS struct { } // Sembuf is equivalent to struct sembuf. +// +// +marshal slice:SembufSlice type Sembuf struct { SemNum uint16 SemOp int16 diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go index e45aadb10..274b1e847 100644 --- a/pkg/abi/linux/shm.go +++ b/pkg/abi/linux/shm.go @@ -51,6 +51,8 @@ const ( // ShmidDS is equivalent to struct shmid64_ds. Source: // include/uapi/asm-generic/shmbuf.h +// +// +marshal type ShmidDS struct { ShmPerm IPCPerm ShmSegsz uint64 @@ -66,6 +68,8 @@ type ShmidDS struct { } // ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h +// +// +marshal type ShmParams struct { ShmMax uint64 ShmMin uint64 @@ -75,6 +79,8 @@ type ShmParams struct { } // ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h +// +// +marshal type ShmInfo struct { UsedIDs int32 // Number of currently existing segments. _ [4]byte diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index 1c330e763..6ca57ffbb 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -214,6 +214,8 @@ const ( ) // Sigevent represents struct sigevent. +// +// +marshal type Sigevent struct { Value uint64 // union sigval {int, void*} Signo int32 diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go index 85fad9956..468c6a387 100644 --- a/pkg/abi/linux/signalfd.go +++ b/pkg/abi/linux/signalfd.go @@ -23,6 +23,8 @@ const ( ) // SignalfdSiginfo is the siginfo encoding for signalfds. +// +// +marshal type SignalfdSiginfo struct { Signo uint32 Errno int32 @@ -41,5 +43,5 @@ type SignalfdSiginfo struct { STime uint64 Addr uint64 AddrLSB uint16 - _ [48]uint8 + _ [48]uint8 `marshal:"unaligned"` } diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 4a14ef691..d156d41e4 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -14,7 +14,10 @@ package linux -import "gvisor.dev/gvisor/pkg/binary" +import ( + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" +) // Address families, from linux/socket.h. const ( @@ -83,7 +86,6 @@ const ( MSG_MORE = 0x8000 MSG_WAITFORONE = 0x10000 MSG_SENDPAGE_NOTLAST = 0x20000 - MSG_REINJECT = 0x8000000 MSG_ZEROCOPY = 0x4000000 MSG_FASTOPEN = 0x20000000 MSG_CMSG_CLOEXEC = 0x40000000 @@ -134,6 +136,15 @@ const ( SHUT_RDWR = 2 ) +// Packet types from <linux/if_packet.h> +const ( + PACKET_HOST = 0 // To us + PACKET_BROADCAST = 1 // To all + PACKET_MULTICAST = 2 // To group + PACKET_OTHERHOST = 3 // To someone else + PACKET_OUTGOING = 4 // Outgoing of any type +) + // Socket options from socket.h. const ( SO_DEBUG = 1 @@ -225,14 +236,18 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. +// +// +marshal type SockAddrInet struct { Family uint16 Port uint16 Addr InetAddr - Zero [8]uint8 // pad to sizeof(struct sockaddr). + _ [8]uint8 // pad to sizeof(struct sockaddr). } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. @@ -247,7 +262,14 @@ type InetMulticastRequestWithNIC struct { InterfaceIndex int32 } +// Inet6Addr is struct in6_addr, from uapi/linux/in6.h. +// +// +marshal +type Inet6Addr [16]byte + // SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h. +// +// +marshal type SockAddrInet6 struct { Family uint16 Port uint16 @@ -257,6 +279,8 @@ type SockAddrInet6 struct { } // SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +// +// +marshal type SockAddrLink struct { Family uint16 Protocol uint16 @@ -273,6 +297,8 @@ type SockAddrLink struct { const UnixPathMax = 108 // SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h. +// +// +marshal type SockAddrUnix struct { Family uint16 Path [UnixPathMax]int8 @@ -282,6 +308,8 @@ type SockAddrUnix struct { // equivalent to struct sockaddr. SockAddr ensures that a well-defined set of // types can be used as socket addresses. type SockAddr interface { + marshal.Marshallable + // implementsSockAddr exists purely to allow a type to indicate that they // implement this interface. This method is a no-op and shouldn't be called. implementsSockAddr() @@ -294,6 +322,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -308,6 +338,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -405,6 +437,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. +// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go index e6860ed49..206f5af7e 100644 --- a/pkg/abi/linux/time.go +++ b/pkg/abi/linux/time.go @@ -93,6 +93,8 @@ const ( const maxSecInDuration = math.MaxInt64 / int64(time.Second) // TimeT represents time_t in <time.h>. It represents time in seconds. +// +// +marshal type TimeT int64 // NsecToTimeT translates nanoseconds to TimeT (seconds). @@ -102,7 +104,7 @@ func NsecToTimeT(nsec int64) TimeT { // Timespec represents struct timespec in <time.h>. // -// +marshal +// +marshal slice:TimespecSlice type Timespec struct { Sec int64 Nsec int64 @@ -158,7 +160,7 @@ const SizeOfTimeval = 16 // Timeval represents struct timeval in <time.h>. // -// +marshal +// +marshal slice:TimevalSlice type Timeval struct { Sec int64 Usec int64 @@ -196,6 +198,8 @@ func DurationToTimeval(dur time.Duration) Timeval { } // Itimerspec represents struct itimerspec in <time.h>. +// +// +marshal type Itimerspec struct { Interval Timespec Value Timespec @@ -206,12 +210,16 @@ type Itimerspec struct { // struct timeval it_interval; /* next value */ // struct timeval it_value; /* current value */ // }; +// +// +marshal type ItimerVal struct { Interval Timeval Value Timeval } // ClockT represents type clock_t. +// +// +marshal type ClockT int64 // ClockTFromDuration converts time.Duration to clock_t. @@ -220,6 +228,8 @@ func ClockTFromDuration(d time.Duration) ClockT { } // Tms represents struct tms, used by times(2). +// +// +marshal type Tms struct { UTime ClockT STime ClockT @@ -229,6 +239,8 @@ type Tms struct { // TimerID represents type timer_t, which identifies a POSIX per-process // interval timer. +// +// +marshal type TimerID int32 // StatxTimestamp represents struct statx_timestamp. diff --git a/pkg/abi/linux/tty.go b/pkg/abi/linux/tty.go index 8ac02aee8..47e65d9fb 100644 --- a/pkg/abi/linux/tty.go +++ b/pkg/abi/linux/tty.go @@ -23,6 +23,8 @@ const ( ) // Winsize is struct winsize, defined in uapi/asm-generic/termios.h. +// +// +marshal type Winsize struct { Row uint16 Col uint16 @@ -31,6 +33,8 @@ type Winsize struct { } // Termios is struct termios, defined in uapi/asm-generic/termbits.h. +// +// +marshal type Termios struct { InputFlags uint32 OutputFlags uint32 @@ -321,9 +325,9 @@ var MasterTermios = KernelTermios{ OutputSpeed: 38400, } -// DefaultSlaveTermios is the default terminal configuration of the slave end -// of a Unix98 pseudoterminal. -var DefaultSlaveTermios = KernelTermios{ +// DefaultReplicaTermios is the default terminal configuration of the replica +// end of a Unix98 pseudoterminal. +var DefaultReplicaTermios = KernelTermios{ InputFlags: ICRNL | IXON, OutputFlags: OPOST | ONLCR, ControlFlags: B38400 | CS8 | CREAD, @@ -337,6 +341,7 @@ var DefaultSlaveTermios = KernelTermios{ // include/uapi/asm-generic/termios.h. // // +stateify savable +// +marshal type WindowSize struct { Rows uint16 Cols uint16 diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go index 60f220a67..cb7c95437 100644 --- a/pkg/abi/linux/utsname.go +++ b/pkg/abi/linux/utsname.go @@ -26,6 +26,8 @@ const ( ) // UtsName represents struct utsname, the struct returned by uname(2). +// +// +marshal type UtsName struct { Sysname [UTSLen + 1]byte Nodename [UTSLen + 1]byte diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index 99180b208..8ef837f27 100644 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go @@ -23,6 +23,9 @@ const ( XATTR_CREATE = 1 XATTR_REPLACE = 2 + XATTR_TRUSTED_PREFIX = "trusted." + XATTR_TRUSTED_PREFIX_LEN = len(XATTR_TRUSTED_PREFIX) + XATTR_USER_PREFIX = "user." XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX) ) diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index ffc918846..bd3a5cce9 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -6,7 +6,10 @@ go_library( name = "amutex", srcs = ["amutex.go"], visibility = ["//:sandbox"], - deps = ["//pkg/syserror"], + deps = [ + "//pkg/context", + "//pkg/syserror", + ], ) go_test( diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go index a078a31db..d7acc1d9f 100644 --- a/pkg/amutex/amutex.go +++ b/pkg/amutex/amutex.go @@ -19,41 +19,17 @@ package amutex import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) // Sleeper must be implemented by users of the abortable mutex to allow for // cancellation of waits. -type Sleeper interface { - // SleepStart is called by the AbortableMutex.Lock() function when the - // mutex is contended and the goroutine is about to sleep. - // - // A channel can be returned that causes the sleep to be canceled if - // it's readable. If no cancellation is desired, nil can be returned. - SleepStart() <-chan struct{} - - // SleepFinish is called by AbortableMutex.Lock() once a contended mutex - // is acquired or the wait is aborted. - SleepFinish(success bool) - - // Interrupted returns true if the wait is aborted. - Interrupted() bool -} +type Sleeper = context.ChannelSleeper // NoopSleeper is a stateless no-op implementation of Sleeper for anonymous // embedding in other types that do not support cancelation. -type NoopSleeper struct{} - -// SleepStart implements Sleeper.SleepStart. -func (NoopSleeper) SleepStart() <-chan struct{} { - return nil -} - -// SleepFinish implements Sleeper.SleepFinish. -func (NoopSleeper) SleepFinish(success bool) {} - -// Interrupted implements Sleeper.Interrupted. -func (NoopSleeper) Interrupted() bool { return false } +type NoopSleeper = context.Context // Block blocks until either receiving from ch succeeds (in which case it // returns nil) or sleeper is interrupted (in which case it returns diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index c8ee0c3b1..069d0395d 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -21,10 +21,15 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// DecodeProgram translates an array of BPF instructions into text format. -func DecodeProgram(program []linux.BPFInstruction) (string, error) { +// DecodeProgram translates a compiled BPF program into text format. +func DecodeProgram(p Program) (string, error) { + return DecodeInstructions(p.instructions) +} + +// DecodeInstructions translates an array of BPF instructions into text format. +func DecodeInstructions(instns []linux.BPFInstruction) (string, error) { var ret bytes.Buffer - for line, s := range program { + for line, s := range instns { ret.WriteString(fmt.Sprintf("%v: ", line)) if err := decode(s, line, &ret); err != nil { return "", err @@ -34,7 +39,7 @@ func DecodeProgram(program []linux.BPFInstruction) (string, error) { return ret.String(), nil } -// Decode translates BPF instruction into text format. +// Decode translates a single BPF instruction into text format. func Decode(inst linux.BPFInstruction) (string, error) { var ret bytes.Buffer err := decode(inst, -1, &ret) diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go index 6a023f0c0..bb971ce21 100644 --- a/pkg/bpf/decoder_test.go +++ b/pkg/bpf/decoder_test.go @@ -93,7 +93,7 @@ func TestDecode(t *testing.T) { } } -func TestDecodeProgram(t *testing.T) { +func TestDecodeInstructions(t *testing.T) { for _, test := range []struct { name string program []linux.BPFInstruction @@ -126,7 +126,7 @@ func TestDecodeProgram(t *testing.T) { program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)}, fail: true}, } { - got, err := DecodeProgram(test.program) + got, err := DecodeInstructions(test.program) if test.fail { if err == nil { t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got) diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go index 7992044d0..caaf99c83 100644 --- a/pkg/bpf/program_builder.go +++ b/pkg/bpf/program_builder.go @@ -32,13 +32,21 @@ type ProgramBuilder struct { // Maps label names to label objects. labels map[string]*label + // unusableLabels are labels that are added before being referenced in a + // jump. Any labels added this way cannot be referenced later in order to + // avoid backwards references. + unusableLabels map[string]bool + // Array of BPF instructions that makes up the program. instructions []linux.BPFInstruction } // NewProgramBuilder creates a new ProgramBuilder instance. func NewProgramBuilder() *ProgramBuilder { - return &ProgramBuilder{labels: map[string]*label{}} + return &ProgramBuilder{ + labels: map[string]*label{}, + unusableLabels: map[string]bool{}, + } } // label contains information to resolve a label to an offset. @@ -108,9 +116,12 @@ func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel s func (b *ProgramBuilder) AddLabel(name string) error { l, ok := b.labels[name] if !ok { - // This is done to catch jump backwards cases, but it's not strictly wrong - // to have unused labels. - return fmt.Errorf("Adding a label that hasn't been used is not allowed: %v", name) + if _, ok = b.unusableLabels[name]; ok { + return fmt.Errorf("label %q already set", name) + } + // Mark the label as unusable. This is done to catch backwards jumps. + b.unusableLabels[name] = true + return nil } if l.target != -1 { return fmt.Errorf("label %q target already set: %v", name, l.target) @@ -141,6 +152,10 @@ func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) { func (b *ProgramBuilder) resolveLabels() error { for key, v := range b.labels { + if _, ok := b.unusableLabels[key]; ok { + return fmt.Errorf("backwards reference detected for label: %q", key) + } + if v.target == -1 { return fmt.Errorf("label target not set: %v", key) } diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go index 92ca5f4c3..37f684f25 100644 --- a/pkg/bpf/program_builder_test.go +++ b/pkg/bpf/program_builder_test.go @@ -26,16 +26,16 @@ func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error { if err != nil { return fmt.Errorf("Instructions() failed: %v", err) } - got, err := DecodeProgram(instructions) + got, err := DecodeInstructions(instructions) if err != nil { - return fmt.Errorf("DecodeProgram('instructions') failed: %v", err) + return fmt.Errorf("DecodeInstructions('instructions') failed: %v", err) } - expectedDecoded, err := DecodeProgram(expected) + expectedDecoded, err := DecodeInstructions(expected) if err != nil { - return fmt.Errorf("DecodeProgram('expected') failed: %v", err) + return fmt.Errorf("DecodeInstructions('expected') failed: %v", err) } if got != expectedDecoded { - return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got) + return fmt.Errorf("DecodeInstructions() failed, expected: %q, got: %q", expectedDecoded, got) } return nil } @@ -124,10 +124,38 @@ func TestProgramBuilderLabelWithNoInstruction(t *testing.T) { } } +// TestProgramBuilderUnusedLabel tests that adding an unused label doesn't +// cause program generation to fail. func TestProgramBuilderUnusedLabel(t *testing.T) { p := NewProgramBuilder() - if err := p.AddLabel("unused"); err == nil { - t.Errorf("AddLabel(unused) should have failed") + p.AddStmt(Ld+Abs+W, 10) + p.AddJump(Jmp+Ja, 10, 0, 0) + + expected := []linux.BPFInstruction{ + Stmt(Ld+Abs+W, 10), + Jump(Jmp+Ja, 10, 0, 0), + } + + if err := p.AddLabel("unused"); err != nil { + t.Errorf("AddLabel(unused) should have succeeded") + } + + if err := validate(p, expected); err != nil { + t.Errorf("Validate() failed: %v", err) + } +} + +// TestProgramBuilderBackwardsReference tests that including a backwards +// reference to a label in a program causes a failure. +func TestProgramBuilderBackwardsReference(t *testing.T) { + p := NewProgramBuilder() + if err := p.AddLabel("bw_label"); err != nil { + t.Errorf("failed to add label") + } + p.AddStmt(Ld+Abs+W, 10) + p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "bw_label", 0) + if _, err := p.Instructions(); err == nil { + t.Errorf("Instructions() should have failed") } } diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index dcd086298..1186f788e 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -20,14 +20,17 @@ go_library( srcs = [ "buffer.go", "buffer_list.go", + "pool.go", "safemem.go", "view.go", "view_unsafe.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/context", "//pkg/log", "//pkg/safemem", + "//pkg/usermem", ], ) @@ -35,9 +38,13 @@ go_test( name = "buffer_test", size = "small", srcs = [ + "pool_test.go", "safemem_test.go", "view_test.go", ], library = ":buffer", - deps = ["//pkg/safemem"], + deps = [ + "//pkg/safemem", + "//pkg/state", + ], ) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index c6d089fd9..311808ae9 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -14,36 +14,26 @@ // Package buffer provides the implementation of a buffer view. // -// A view is an flexible buffer, backed by a pool, supporting the safecopy -// operations natively as well as the ability to grow via either prepend or -// append, as well as shrink. +// A view is an flexible buffer, supporting the safecopy operations natively as +// well as the ability to grow via either prepend or append, as well as shrink. package buffer -import ( - "sync" -) - -const bufferSize = 8144 // See below. - // buffer encapsulates a queueable byte buffer. // -// Note that the total size is slightly less than two pages. This is done -// intentionally to ensure that the buffer object aligns with runtime -// internals. We have no hard size or alignment requirements. This two page -// size will effectively minimize internal fragmentation, but still have a -// large enough chunk to limit excessive segmentation. -// // +stateify savable type buffer struct { - data [bufferSize]byte + data []byte read int write int bufferEntry } -// reset resets internal data. -// -// This must be called before returning the buffer to the pool. +// init performs in-place initialization for zero value. +func (b *buffer) init(size int) { + b.data = make([]byte, size) +} + +// Reset resets read and write locations, effectively emptying the buffer. func (b *buffer) Reset() { b.read = 0 b.write = 0 @@ -85,10 +75,3 @@ func (b *buffer) WriteMove(n int) { func (b *buffer) WriteSlice() []byte { return b.data[b.write:] } - -// bufferPool is a pool for buffers. -var bufferPool = sync.Pool{ - New: func() interface{} { - return new(buffer) - }, -} diff --git a/pkg/buffer/pool.go b/pkg/buffer/pool.go new file mode 100644 index 000000000..7ad6132ab --- /dev/null +++ b/pkg/buffer/pool.go @@ -0,0 +1,83 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +const ( + // embeddedCount is the number of buffer structures embedded in the pool. It + // is also the number for overflow allocations. + embeddedCount = 8 + + // defaultBufferSize is the default size for each underlying storage buffer. + // + // It is slightly less than two pages. This is done intentionally to ensure + // that the buffer object aligns with runtime internals. This two page size + // will effectively minimize internal fragmentation, but still have a large + // enough chunk to limit excessive segmentation. + defaultBufferSize = 8144 +) + +// pool allocates buffer. +// +// It contains an embedded buffer storage for fast path when the number of +// buffers needed is small. +// +// +stateify savable +type pool struct { + bufferSize int + avail []buffer `state:"nosave"` + embeddedStorage [embeddedCount]buffer `state:"wait"` +} + +// get gets a new buffer from p. +func (p *pool) get() *buffer { + if p.avail == nil { + p.avail = p.embeddedStorage[:] + } + if len(p.avail) == 0 { + p.avail = make([]buffer, embeddedCount) + } + if p.bufferSize <= 0 { + p.bufferSize = defaultBufferSize + } + buf := &p.avail[0] + buf.init(p.bufferSize) + p.avail = p.avail[1:] + return buf +} + +// put releases buf. +func (p *pool) put(buf *buffer) { + // Remove reference to the underlying storage, allowing it to be garbage + // collected. + buf.data = nil +} + +// setBufferSize sets the size of underlying storage buffer for future +// allocations. It can be called at any time. +func (p *pool) setBufferSize(size int) { + p.bufferSize = size +} + +// afterLoad is invoked by stateify. +func (p *pool) afterLoad() { + // S/R does not save subslice into embeddedStorage correctly. Restore + // available portion of embeddedStorage manually. Restore as nil if none used. + for i := len(p.embeddedStorage); i > 0; i-- { + if p.embeddedStorage[i-1].data != nil { + p.avail = p.embeddedStorage[i:] + break + } + } +} diff --git a/pkg/buffer/pool_test.go b/pkg/buffer/pool_test.go new file mode 100644 index 000000000..8584bac89 --- /dev/null +++ b/pkg/buffer/pool_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "testing" +) + +func TestGetDefaultBufferSize(t *testing.T) { + var p pool + for i := 0; i < embeddedCount*2; i++ { + buf := p.get() + if got, want := len(buf.data), defaultBufferSize; got != want { + t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want) + } + } +} + +func TestGetCustomBufferSize(t *testing.T) { + const size = 100 + + var p pool + p.setBufferSize(size) + for i := 0; i < embeddedCount*2; i++ { + buf := p.get() + if got, want := len(buf.data), size; got != want { + t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want) + } + } +} + +func TestPut(t *testing.T) { + var p pool + buf := p.get() + p.put(buf) + if buf.data != nil { + t.Errorf("buf.data = %x, want nil", buf.data) + } +} diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index b789e56e9..8b42575b4 100644 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go @@ -44,7 +44,7 @@ func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, e // Need at least one buffer. firstBuf := v.data.Back() if firstBuf == nil { - firstBuf = bufferPool.Get().(*buffer) + firstBuf = v.pool.get() v.data.PushBack(firstBuf) } @@ -56,7 +56,7 @@ func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, e count -= l blocks = append(blocks, firstBuf.WriteBlock()) for count > 0 { - emptyBuf := bufferPool.Get().(*buffer) + emptyBuf := v.pool.get() v.data.PushBack(emptyBuf) block := emptyBuf.WriteBlock().TakeFirst64(count) count -= uint64(block.Len()) diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go index 47f357e0c..721cc5934 100644 --- a/pkg/buffer/safemem_test.go +++ b/pkg/buffer/safemem_test.go @@ -23,6 +23,8 @@ import ( ) func TestSafemem(t *testing.T) { + const bufferSize = defaultBufferSize + testCases := []struct { name string input string diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index e6901eadb..00652d675 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -27,6 +27,7 @@ import ( type View struct { data bufferList size int64 + pool pool } // TrimFront removes the first count bytes from the buffer. @@ -81,7 +82,7 @@ func (v *View) advanceRead(count int64) { buf = buf.Next() // Iterate. v.data.Remove(oldBuf) oldBuf.Reset() - bufferPool.Put(oldBuf) + v.pool.put(oldBuf) // Update counts. count -= sz @@ -118,7 +119,7 @@ func (v *View) Truncate(length int64) { // Drop the buffer completely; see above. v.data.Remove(buf) buf.Reset() - bufferPool.Put(buf) + v.pool.put(buf) v.size -= sz } } @@ -137,7 +138,7 @@ func (v *View) Grow(length int64, zero bool) { // Is there some space in the last buffer? if buf == nil || buf.Full() { - buf = bufferPool.Get().(*buffer) + buf = v.pool.get() v.data.PushBack(buf) } @@ -181,7 +182,7 @@ func (v *View) Prepend(data []byte) { for len(data) > 0 { // Do we need an empty buffer? - buf := bufferPool.Get().(*buffer) + buf := v.pool.get() v.data.PushFront(buf) // The buffer is empty; copy last chunk. @@ -211,7 +212,7 @@ func (v *View) Append(data []byte) { // Ensure there's a buffer with space. if buf == nil || buf.Full() { - buf = bufferPool.Get().(*buffer) + buf = v.pool.get() v.data.PushBack(buf) } @@ -297,7 +298,7 @@ func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { // Ensure we have an empty buffer. if buf == nil || buf.Full() { - buf = bufferPool.Get().(*buffer) + buf = v.pool.get() v.data.PushBack(buf) } diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 3db1bc6ee..839af0223 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -16,11 +16,16 @@ package buffer import ( "bytes" + "context" "io" "strings" "testing" + + "gvisor.dev/gvisor/pkg/state" ) +const bufferSize = defaultBufferSize + func fillAppend(v *View, data []byte) { v.Append(data) } @@ -50,6 +55,30 @@ var fillFuncs = map[string]func(*View, []byte){ "writeFromReaderEnd": fillWriteFromReaderEnd, } +func BenchmarkReadAt(b *testing.B) { + b.ReportAllocs() + var v View + v.Append(make([]byte, 100)) + + buf := make([]byte, 10) + for i := 0; i < b.N; i++ { + v.ReadAt(buf, 0) + } +} + +func BenchmarkWriteRead(b *testing.B) { + b.ReportAllocs() + var v View + sz := 1000 + wbuf := make([]byte, sz) + rbuf := bytes.NewBuffer(make([]byte, sz)) + for i := 0; i < b.N; i++ { + v.Append(wbuf) + rbuf.Reset() + v.ReadToWriter(rbuf, int64(sz)) + } +} + func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) { t.Helper() d := make([]byte, n) @@ -465,3 +494,51 @@ func TestView(t *testing.T) { } } } + +func doSaveAndLoad(t *testing.T, toSave, toLoad *View) { + t.Helper() + var buf bytes.Buffer + ctx := context.Background() + if _, err := state.Save(ctx, &buf, toSave); err != nil { + t.Fatal("state.Save:", err) + } + if _, err := state.Load(ctx, bytes.NewReader(buf.Bytes()), toLoad); err != nil { + t.Fatal("state.Load:", err) + } +} + +func TestSaveRestoreViewEmpty(t *testing.T) { + var toSave View + var v View + doSaveAndLoad(t, &toSave, &v) + + if got := v.pool.avail; got != nil { + t.Errorf("pool is not in zero state: v.pool.avail = %v, want nil", got) + } + if got := v.Flatten(); len(got) != 0 { + t.Errorf("v.Flatten() = %x, want []", got) + } +} + +func TestSaveRestoreView(t *testing.T) { + // Create data that fits 2.5 slots. + data := bytes.Join([][]byte{ + bytes.Repeat([]byte{1, 2}, defaultBufferSize), + bytes.Repeat([]byte{3}, defaultBufferSize/2), + }, nil) + + var toSave View + toSave.Append(data) + + var v View + doSaveAndLoad(t, &toSave, &v) + + // Next available slot at index 3; 0-2 slot are used. + i := 3 + if got, want := &v.pool.avail[0], &v.pool.embeddedStorage[i]; got != want { + t.Errorf("next available buffer points to %p, want %p (&v.pool.embeddedStorage[%d])", got, want, i) + } + if got := v.Flatten(); !bytes.Equal(got, data) { + t.Errorf("v.Flatten() = %x, want %x", got, data) + } +} diff --git a/pkg/context/BUILD b/pkg/context/BUILD index 239f31149..f33e23bf7 100644 --- a/pkg/context/BUILD +++ b/pkg/context/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["context.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/amutex", "//pkg/log", ], ) diff --git a/pkg/context/context.go b/pkg/context/context.go index 5319b6d8d..2613bc752 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -26,7 +26,6 @@ import ( "context" "time" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -68,9 +67,10 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { // In both cases, values extracted from the Context should be used instead. type Context interface { log.Logger - amutex.Sleeper context.Context + ChannelSleeper + // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate // is true and the Context represents a Task, the Task's AddressSpace is @@ -85,29 +85,60 @@ type Context interface { UninterruptibleSleepFinish(activate bool) } -// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep -// methods for anonymous embedding in other types that do not implement sleeps. -type NoopSleeper struct { - amutex.NoopSleeper +// A ChannelSleeper represents a goroutine that may sleep interruptibly, where +// interruption is indicated by a channel becoming readable. +type ChannelSleeper interface { + // SleepStart is called before going to sleep interruptibly. If SleepStart + // returns a non-nil channel and that channel becomes ready for receiving + // while the goroutine is sleeping, the goroutine should be woken, and + // SleepFinish(false) should be called. Otherwise, SleepFinish(true) should + // be called after the goroutine stops sleeping. + SleepStart() <-chan struct{} + + // SleepFinish is called after an interruptibly-sleeping goroutine stops + // sleeping, as documented by SleepStart. + SleepFinish(success bool) + + // Interrupted returns true if the channel returned by SleepStart is + // ready for receiving. + Interrupted() bool +} + +// NoopSleeper is a noop implementation of ChannelSleeper and +// Context.UninterruptibleSleep* methods for anonymous embedding in other types +// that do not implement special behavior around sleeps. +type NoopSleeper struct{} + +// SleepStart implements ChannelSleeper.SleepStart. +func (NoopSleeper) SleepStart() <-chan struct{} { + return nil +} + +// SleepFinish implements ChannelSleeper.SleepFinish. +func (NoopSleeper) SleepFinish(success bool) {} + +// Interrupted implements ChannelSleeper.Interrupted. +func (NoopSleeper) Interrupted() bool { + return false } -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} +// UninterruptibleSleepStart implements Context.UninterruptibleSleepStart. +func (NoopSleeper) UninterruptibleSleepStart(deactivate bool) {} -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} +// UninterruptibleSleepFinish implements Context.UninterruptibleSleepFinish. +func (NoopSleeper) UninterruptibleSleepFinish(activate bool) {} -// Deadline returns zero values, meaning no deadline. +// Deadline implements context.Context.Deadline. func (NoopSleeper) Deadline() (time.Time, bool) { return time.Time{}, false } -// Done returns nil. +// Done implements context.Context.Done. func (NoopSleeper) Done() <-chan struct{} { return nil } -// Err returns nil. +// Err returns context.Context.Err. func (NoopSleeper) Err() error { return nil } diff --git a/pkg/coverage/BUILD b/pkg/coverage/BUILD new file mode 100644 index 000000000..a198e8028 --- /dev/null +++ b/pkg/coverage/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "coverage", + srcs = ["coverage.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/sync", + "//pkg/usermem", + "@io_bazel_rules_go//go/tools/coverdata", + ], +) diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go new file mode 100644 index 000000000..a4f4b2c5e --- /dev/null +++ b/pkg/coverage/coverage.go @@ -0,0 +1,172 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package coverage provides an interface through which Go coverage data can +// be collected, converted to kcov format, and exposed to userspace. +// +// Coverage can be enabled by calling bazel {build,test} with +// --collect_coverage_data and --instrumentation_filter with the desired +// coverage surface. This causes bazel to use the Go cover tool manually to +// generate instrumented files. It injects a hook that registers all coverage +// data with the coverdata package. +package coverage + +import ( + "fmt" + "io" + "sort" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" + + "github.com/bazelbuild/rules_go/go/tools/coverdata" +) + +// KcovAvailable returns whether the kcov coverage interface is available. It is +// available as long as coverage is enabled for some files. +func KcovAvailable() bool { + return len(coverdata.Cover.Blocks) > 0 +} + +// coverageMu must be held while accessing coverdata.Cover. This prevents +// concurrent reads/writes from multiple threads collecting coverage data. +var coverageMu sync.RWMutex + +// once ensures that globalData is only initialized once. +var once sync.Once + +var globalData struct { + // files is the set of covered files sorted by filename. It is calculated at + // startup. + files []string + + // syntheticPCs are a set of PCs calculated at startup, where the PC + // at syntheticPCs[i][j] corresponds to file i, block j. + syntheticPCs [][]uint64 +} + +// ClearCoverageData clears existing coverage data. +func ClearCoverageData() { + coverageMu.Lock() + defer coverageMu.Unlock() + for _, counters := range coverdata.Cover.Counters { + for index := 0; index < len(counters); index++ { + atomic.StoreUint32(&counters[index], 0) + } + } +} + +var coveragePool = sync.Pool{ + New: func() interface{} { + return make([]byte, 0) + }, +} + +// ConsumeCoverageData builds and writes the collection of covered PCs. It +// returns the number of bytes written. +// +// In Linux, a kernel configuration is set that compiles the kernel with a +// custom function that is called at the beginning of every basic block, which +// updates the memory-mapped coverage information. The Go coverage tool does not +// allow us to inject arbitrary instructions into basic blocks, but it does +// provide data that we can convert to a kcov-like format and transfer them to +// userspace through a memory mapping. +// +// Note that this is not a strict implementation of kcov, which is especially +// tricky to do because we do not have the same coverage tools available in Go +// that that are available for the actual Linux kernel. In Linux, a kernel +// configuration is set that compiles the kernel with a custom function that is +// called at the beginning of every basic block to write program counters to the +// kcov memory mapping. In Go, however, coverage tools only give us a count of +// basic blocks as they are executed. Every time we return to userspace, we +// collect the coverage information and write out PCs for each block that was +// executed, providing userspace with the illusion that the kcov data is always +// up to date. For convenience, we also generate a unique synthetic PC for each +// block instead of using actual PCs. Finally, we do not provide thread-specific +// coverage data (each kcov instance only contains PCs executed by the thread +// owning it); instead, we will supply data for any file specified by -- +// instrumentation_filter. +// +// Note that we "consume", i.e. clear, coverdata when this function is run, to +// ensure that each event is only reported once. Due to the limitations of Go +// coverage tools, we reset the global coverage data every time this function is +// run. +func ConsumeCoverageData(w io.Writer) int { + once.Do(initCoverageData) + + coverageMu.Lock() + defer coverageMu.Unlock() + + total := 0 + var pcBuffer [8]byte + for fileIndex, file := range globalData.files { + counters := coverdata.Cover.Counters[file] + for index := 0; index < len(counters); index++ { + if atomic.LoadUint32(&counters[index]) == 0 { + continue + } + // Non-zero coverage data found; consume it and report as a PC. + atomic.StoreUint32(&counters[index], 0) + pc := globalData.syntheticPCs[fileIndex][index] + usermem.ByteOrder.PutUint64(pcBuffer[:], pc) + n, err := w.Write(pcBuffer[:]) + if err != nil { + if err == io.EOF { + // Simply stop writing if we encounter EOF; it's ok if we attempted to + // write more than we can hold. + return total + n + } + panic(fmt.Sprintf("Internal error writing PCs to kcov area: %v", err)) + } + total += n + } + } + + if total == 0 { + // An empty profile indicates that coverage is not enabled, in which case + // there shouldn't be any task work registered. + panic("kcov task work is registered, but no coverage data was found") + } + return total +} + +// initCoverageData initializes globalData. It should only be called once, +// before any kcov data is written. +func initCoverageData() { + // First, order all files. Then calculate synthetic PCs for every block + // (using the well-defined ordering for files as well). + for file := range coverdata.Cover.Blocks { + globalData.files = append(globalData.files, file) + } + sort.Strings(globalData.files) + + // nextSyntheticPC is the first PC that we generate for a block. + // + // This uses a standard-looking kernel range for simplicity. + // + // FIXME(b/160639712): This is only necessary because syzkaller requires + // addresses in the kernel range. If we can remove this constraint, then we + // should be able to use the actual addresses. + var nextSyntheticPC uint64 = 0xffffffff80000000 + for _, file := range globalData.files { + blocks := coverdata.Cover.Blocks[file] + thisFile := make([]uint64, 0, len(blocks)) + for range blocks { + thisFile = append(thisFile, nextSyntheticPC) + nextSyntheticPC++ // Advance. + } + globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile) + } +} diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go index c9bd40e1b..e4ae0d689 100644 --- a/pkg/cpuid/cpuid_parse_x86_test.go +++ b/pkg/cpuid/cpuid_parse_x86_test.go @@ -32,27 +32,27 @@ func kernelVersion() (int, int, error) { return 0, 0, err } - var r string + var sb strings.Builder for _, b := range u.Release { if b == 0 { break } - r += string(b) + sb.WriteByte(byte(b)) } - s := strings.Split(r, ".") + s := strings.Split(sb.String(), ".") if len(s) < 2 { - return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", r) + return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", sb.String()) } major, err := strconv.Atoi(s[0]) if err != nil { - return 0, 0, fmt.Errorf("error parsing major version %q in %q: %v", s[0], r, err) + return 0, 0, fmt.Errorf("error parsing major version %q in %q: %w", s[0], sb.String(), err) } minor, err := strconv.Atoi(s[1]) if err != nil { - return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %v", s[1], r, err) + return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %w", s[1], sb.String(), err) } return major, minor, nil diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 83bcfe220..cc6b0cdf1 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -49,7 +49,7 @@ func fixCount(n int, err error) (int, error) { // Read implements io.Reader. func (r *ReadWriter) Read(b []byte) (int, error) { - c, err := fixCount(syscall.Read(int(atomic.LoadInt64(&r.fd)), b)) + c, err := fixCount(syscall.Read(r.FD(), b)) if c == 0 && len(b) > 0 && err == nil { return 0, io.EOF } @@ -62,7 +62,7 @@ func (r *ReadWriter) Read(b []byte) (int, error) { func (r *ReadWriter) ReadAt(b []byte, off int64) (c int, err error) { for len(b) > 0 { var m int - m, err = fixCount(syscall.Pread(int(atomic.LoadInt64(&r.fd)), b, off)) + m, err = fixCount(syscall.Pread(r.FD(), b, off)) if m == 0 && err == nil { return c, io.EOF } @@ -82,7 +82,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) { var n, remaining int for remaining = len(b); remaining > 0; { woff := len(b) - remaining - n, err = syscall.Write(int(atomic.LoadInt64(&r.fd)), b[woff:]) + n, err = syscall.Write(r.FD(), b[woff:]) if n > 0 { // syscall.Write wrote some bytes. This is the common case. @@ -110,7 +110,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) { func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) { for len(b) > 0 { var m int - m, err = fixCount(syscall.Pwrite(int(atomic.LoadInt64(&r.fd)), b, off)) + m, err = fixCount(syscall.Pwrite(r.FD(), b, off)) if err != nil { break } @@ -121,6 +121,16 @@ func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) { return } +// FD returns the owned file descriptor. Ownership remains unchanged. +func (r *ReadWriter) FD() int { + return int(atomic.LoadInt64(&r.fd)) +} + +// String implements Stringer.String(). +func (r *ReadWriter) String() string { + return fmt.Sprintf("FD: %d", r.FD()) +} + // FD owns a host file descriptor. // // It is similar to os.File, with a few important distinctions: @@ -167,6 +177,23 @@ func NewFromFile(file *os.File) (*FD, error) { return New(fd), nil } +// NewFromFiles creates new FDs for each file in the slice. +func NewFromFiles(files []*os.File) ([]*FD, error) { + rv := make([]*FD, 0, len(files)) + for _, f := range files { + new, err := NewFromFile(f) + if err != nil { + // Cleanup on error. + for _, fd := range rv { + fd.Close() + } + return nil, err + } + rv = append(rv, new) + } + return rv, nil +} + // Open is equivalent to open(2). func Open(path string, openmode int, perm uint32) (*FD, error) { f, err := syscall.Open(path, openmode|syscall.O_LARGEFILE, perm) @@ -204,11 +231,6 @@ func (f *FD) Release() int { return int(atomic.SwapInt64(&f.fd, -1)) } -// FD returns the file descriptor owned by FD. FD retains ownership. -func (f *FD) FD() int { - return int(atomic.LoadInt64(&f.fd)) -} - // File converts the FD to an os.File. // // FD does not transfer ownership of the file descriptor (it will be @@ -219,7 +241,7 @@ func (f *FD) FD() int { // This operation is somewhat expensive, so care should be taken to minimize // its use. func (f *FD) File() (*os.File, error) { - fd, err := syscall.Dup(int(atomic.LoadInt64(&f.fd))) + fd, err := syscall.Dup(f.FD()) if err != nil { return nil, err } diff --git a/pkg/fdnotifier/poll_unsafe.go b/pkg/fdnotifier/poll_unsafe.go index 4225b04dd..ec2f997a2 100644 --- a/pkg/fdnotifier/poll_unsafe.go +++ b/pkg/fdnotifier/poll_unsafe.go @@ -65,8 +65,7 @@ func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask { // epollWait performs a blocking wait on epfd. // -// Preconditions: -// * len(events) > 0 +// Preconditions: len(events) > 0 func epollWait(epfd int, events []syscall.EpollEvent, msec int) (int, error) { if len(events) == 0 { panic("Empty events passed to EpollWait") diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index ec742c091..c4a3366ce 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -179,8 +179,10 @@ const ( // Connect blocks until the peer Endpoint has called Endpoint.RecvFirst(). // -// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), -// ep.SendRecv(), and ep.SendLast() have never been called. +// Preconditions: +// * ep is a client Endpoint. +// * ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and ep.SendLast() have never +// been called. func (ep *Endpoint) Connect() error { err := ep.ctrlConnect() if err == nil { @@ -192,8 +194,9 @@ func (ep *Endpoint) Connect() error { // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then // returns the datagram length specified by that call. // -// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and -// ep.SendLast() have never been called. +// Preconditions: +// * ep is a server Endpoint. +// * ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never been called. func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err @@ -211,10 +214,12 @@ func (ep *Endpoint) RecvFirst() (uint32, error) { // datagram length, then blocks until the peer Endpoint calls // Endpoint.SendRecv() or Endpoint.SendLast(). // -// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or -// ep.RecvFirst() has returned an error. ep.SendLast() has never been called. -// If ep is a client Endpoint, ep.Connect() has previously been called and -// returned nil. +// Preconditions: +// * dataLen <= ep.DataCap(). +// * No previous call to ep.SendRecv() or ep.RecvFirst() has returned an error. +// * ep.SendLast() has never been called. +// * If ep is a client Endpoint, ep.Connect() has previously been called and +// returned nil. func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { if dataLen > ep.dataCap { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) @@ -240,10 +245,12 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or // Endpoint.RecvFirst() to return with the given datagram length. // -// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or -// ep.RecvFirst() has returned an error. ep.SendLast() has never been called. -// If ep is a client Endpoint, ep.Connect() has previously been called and -// returned nil. +// Preconditions: +// * dataLen <= ep.DataCap(). +// * No previous call to ep.SendRecv() or ep.RecvFirst() has returned an error. +// * ep.SendLast() has never been called. +// * If ep is a client Endpoint, ep.Connect() has previously been called and +// returned nil. func (ep *Endpoint) SendLast(dataLen uint32) error { if dataLen > ep.dataCap { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD new file mode 100644 index 000000000..eda82cfc1 --- /dev/null +++ b/pkg/iovec/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "iovec", + srcs = ["iovec.go"], + visibility = ["//:sandbox"], + deps = ["//pkg/abi/linux"], +) + +go_test( + name = "iovec_test", + size = "small", + srcs = ["iovec_test.go"], + library = ":iovec", + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go new file mode 100644 index 000000000..dd70fe80f --- /dev/null +++ b/pkg/iovec/iovec.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package iovec provides helpers to interact with vectorized I/O on host +// system. +package iovec + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// MaxIovs is the maximum number of iovecs host platform can accept. +var MaxIovs = linux.UIO_MAXIOV + +// Builder is a builder for slice of syscall.Iovec. +type Builder struct { + iovec []syscall.Iovec + storage [8]syscall.Iovec + + // overflow tracks the last buffer when iovec length is at MaxIovs. + overflow []byte +} + +// Add adds buf to b preparing to be written. Zero-length buf won't be added. +func (b *Builder) Add(buf []byte) { + if len(buf) == 0 { + return + } + if b.iovec == nil { + b.iovec = b.storage[:0] + } + if len(b.iovec) >= MaxIovs { + b.addByAppend(buf) + return + } + b.iovec = append(b.iovec, syscall.Iovec{ + Base: &buf[0], + Len: uint64(len(buf)), + }) + // Keep the last buf if iovec is at max capacity. We will need to append to it + // for later bufs. + if len(b.iovec) == MaxIovs { + n := len(buf) + b.overflow = buf[:n:n] + } +} + +func (b *Builder) addByAppend(buf []byte) { + b.overflow = append(b.overflow, buf...) + b.iovec[len(b.iovec)-1] = syscall.Iovec{ + Base: &b.overflow[0], + Len: uint64(len(b.overflow)), + } +} + +// Build returns the final Iovec slice. The length of returned iovec will not +// excceed MaxIovs. +func (b *Builder) Build() []syscall.Iovec { + return b.iovec +} diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go new file mode 100644 index 000000000..a3900c299 --- /dev/null +++ b/pkg/iovec/iovec_test.go @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package iovec + +import ( + "bytes" + "fmt" + "syscall" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func TestBuilderEmpty(t *testing.T) { + var builder Builder + iovecs := builder.Build() + if got, want := len(iovecs), 0; got != want { + t.Errorf("len(iovecs) = %d, want %d", got, want) + } +} + +func TestBuilderBuild(t *testing.T) { + a := []byte{1, 2} + b := []byte{3, 4, 5} + + var builder Builder + builder.Add(a) + builder.Add(b) + builder.Add(nil) // Nil slice won't be added. + builder.Add([]byte{}) // Empty slice won't be added. + iovecs := builder.Build() + + if got, want := len(iovecs), 2; got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } + for i, data := range [][]byte{a, b} { + if got, want := *iovecs[i].Base, data[0]; got != want { + t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want) + } + if got, want := iovecs[i].Len, uint64(len(data)); got != want { + t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want) + } + } +} + +func TestBuilderBuildMaxIov(t *testing.T) { + for _, test := range []struct { + numIov int + }{ + { + numIov: MaxIovs - 1, + }, + { + numIov: MaxIovs, + }, + { + numIov: MaxIovs + 1, + }, + { + numIov: MaxIovs + 10, + }, + } { + name := fmt.Sprintf("numIov=%v", test.numIov) + t.Run(name, func(t *testing.T) { + var data []byte + var builder Builder + for i := 0; i < test.numIov; i++ { + buf := []byte{byte(i)} + builder.Add(buf) + data = append(data, buf...) + } + iovec := builder.Build() + + // Check the expected length of iovec. + wantNum := test.numIov + if wantNum > MaxIovs { + wantNum = MaxIovs + } + if got, want := len(iovec), wantNum; got != want { + t.Errorf("len(iovec) = %d, want %d", got, want) + } + + // Test a real read-write. + var fds [2]int + if err := unix.Pipe(fds[:]); err != nil { + t.Fatalf("Pipe: %v", err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if int(wrote) != len(data) || e != 0 { + t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data)) + } + + got := make([]byte, len(data)) + if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil { + t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got)) + } + + if !bytes.Equal(got, data) { + t.Errorf("read: got data %v, want %v", got, data) + } + }) + } +} diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md new file mode 100644 index 000000000..51d0d40e5 --- /dev/null +++ b/pkg/lisafs/README.md @@ -0,0 +1,363 @@ +# Replacing 9P + +## Background + +The Linux filesystem model consists of the following key aspects (modulo mounts, +which are outside the scope of this discussion): + +- A `struct inode` represents a "filesystem object", such as a directory or a + regular file. "Filesystem object" is most precisely defined by the practical + properties of an inode, such as an immutable type (regular file, directory, + symbolic link, etc.) and its independence from the path originally used to + obtain it. + +- A `struct dentry` represents a node in a filesystem tree. Semantically, each + dentry is immutably associated with an inode representing the filesystem + object at that position. (Linux implements optimizations involving reuse of + unreferenced dentries, which allows their associated inodes to change, but + this is outside the scope of this discussion.) + +- A `struct file` represents an open file description (hereafter FD) and is + needed to perform I/O. Each FD is immutably associated with the dentry + through which it was opened. + +The current gVisor virtual filesystem implementation (hereafter VFS1) closely +imitates the Linux design: + +- `struct inode` => `fs.Inode` + +- `struct dentry` => `fs.Dirent` + +- `struct file` => `fs.File` + +gVisor accesses most external filesystems through a variant of the 9P2000.L +protocol, including extensions for performance (`walkgetattr`) and for features +not supported by vanilla 9P2000.L (`flushf`, `lconnect`). The 9P protocol family +is inode-based; 9P fids represent a file (equivalently "file system object"), +and the protocol is structured around alternatively obtaining fids to represent +files (with `walk` and, in gVisor, `walkgetattr`) and performing operations on +those fids. + +In the sections below, a **shared** filesystem is a filesystem that is *mutably* +accessible by multiple concurrent clients, such that a **non-shared** filesystem +is a filesystem that is either read-only or accessible by only a single client. + +## Problems + +### Serialization of Path Component RPCs + +Broadly speaking, VFS1 traverses each path component in a pathname, alternating +between verifying that each traversed dentry represents an inode that represents +a searchable directory and moving to the next dentry in the path. + +In the context of a remote filesystem, the structure of this traversal means +that - modulo caching - a path involving N components requires at least N-1 +*sequential* RPCs to obtain metadata for intermediate directories, incurring +significant latency. (In vanilla 9P2000.L, 2(N-1) RPCs are required: N-1 `walk` +and N-1 `getattr`. We added the `walkgetattr` RPC to reduce this overhead.) On +non-shared filesystems, this overhead is primarily significant during +application startup; caching mitigates much of this overhead at steady state. On +shared filesystems, where correct caching requires revalidation (requiring RPCs +for each revalidated directory anyway), this overhead is consistently ruinous. + +### Inefficient RPCs + +9P is not exceptionally economical with RPCs in general. In addition to the +issue described above: + +- Opening an existing file in 9P involves at least 2 RPCs: `walk` to produce + an unopened fid representing the file, and `lopen` to open the fid. + +- Creating a file also involves at least 2 RPCs: `walk` to produce an unopened + fid representing the parent directory, and `lcreate` to create the file and + convert the fid to an open fid representing the created file. In practice, + both the Linux and gVisor 9P clients expect to have an unopened fid for the + created file (necessitating an additional `walk`), as well as attributes for + the created file (necessitating an additional `getattr`), for a total of 4 + RPCs. (In a shared filesystem, where whether a file already exists can + change between RPCs, a correct implementation of `open(O_CREAT)` would have + to alternate between these two paths (plus `clunk`ing the temporary fid + between alternations, since the nature of the `fid` differs between the two + paths). Neither Linux nor gVisor implement the required alternation, so + `open(O_CREAT)` without `O_EXCL` can spuriously fail with `EEXIST` on both.) + +- Closing (`clunk`ing) a fid requires an RPC. VFS1 issues this RPC + asynchronously in an attempt to reduce critical path latency, but scheduling + overhead makes this not clearly advantageous in practice. + +- `read` and `readdir` can return partial reads without a way to indicate EOF, + necessitating an additional final read to detect EOF. + +- Operations that affect filesystem state do not consistently return updated + filesystem state. In gVisor, the client implementation attempts to handle + this by tracking what it thinks updated state "should" be; this is complex, + and especially brittle for timestamps (which are often not arbitrarily + settable). In Linux, the client implemtation invalidates cached metadata + whenever it performs such an operation, and reloads it when a dentry + corresponding to an inode with no valid cached metadata is revalidated; this + is simple, but necessitates an additional `getattr`. + +### Dentry/Inode Ambiguity + +As noted above, 9P's documentation tends to imply that unopened fids represent +an inode. In practice, most filesystem APIs present very limited interfaces for +working with inodes at best, such that the interpretation of unopened fids +varies: + +- Linux's 9P client associates unopened fids with (dentry, uid) pairs. When + caching is enabled, it also associates each inode with the first fid opened + writably that references that inode, in order to support page cache + writeback. + +- gVisor's 9P client associates unopened fids with inodes, and also caches + opened fids in inodes in a manner similar to Linux. + +- The runsc fsgofer associates unopened fids with both "dentries" (host + filesystem paths) and "inodes" (host file descriptors); which is used + depends on the operation invoked on the fid. + +For non-shared filesystems, this confusion has resulted in correctness issues +that are (in gVisor) currently handled by a number of coarse-grained locks that +serialize renames with all other filesystem operations. For shared filesystems, +this means inconsistent behavior in the presence of concurrent mutation. + +## Design + +Almost all Linux filesystem syscalls describe filesystem resources in one of two +ways: + +- Path-based: A filesystem position is described by a combination of a + starting position and a sequence of path components relative to that + position, where the starting position is one of: + + - The VFS root (defined by mount namespace and chroot), for absolute paths + + - The VFS position of an existing FD, for relative paths passed to `*at` + syscalls (e.g. `statat`) + + - The current working directory, for relative paths passed to non-`*at` + syscalls and `*at` syscalls with `AT_FDCWD` + +- File-description-based: A filesystem object is described by an existing FD, + passed to a `f*` syscall (e.g. `fstat`). + +Many of our issues with 9P arise from its (and VFS') interposition of a model +based on inodes between the filesystem syscall API and filesystem +implementations. We propose to replace 9P with a protocol that does not feature +inodes at all, and instead closely follows the filesystem syscall API by +featuring only path-based and FD-based operations, with minimal deviations as +necessary to ameliorate deficiencies in the syscall interface (see below). This +approach addresses the issues described above: + +- Even on shared filesystems, most application filesystem syscalls are + translated to a single RPC (possibly excepting special cases described + below), which is a logical lower bound. + +- The behavior of application syscalls on shared filesystems is + straightforwardly predictable: path-based syscalls are translated to + path-based RPCs, which will re-lookup the file at that path, and FD-based + syscalls are translated to FD-based RPCs, which use an existing open file + without performing another lookup. (This is at least true on gofers that + proxy the host local filesystem; other filesystems that lack support for + e.g. certain operations on FDs may have different behavior, but this + divergence is at least still predictable and inherent to the underlying + filesystem implementation.) + +Note that this approach is only feasible in gVisor's next-generation virtual +filesystem (VFS2), which does not assume the existence of inodes and allows the +remote filesystem client to translate whole path-based syscalls into RPCs. Thus +one of the unavoidable tradeoffs associated with such a protocol vs. 9P is the +inability to construct a Linux client that is performance-competitive with +gVisor. + +### File Permissions + +Many filesystem operations are side-effectual, such that file permissions must +be checked before such operations take effect. The simplest approach to file +permission checking is for the sentry to obtain permissions from the remote +filesystem, then apply permission checks in the sentry before performing the +application-requested operation. However, this requires an additional RPC per +application syscall (which can't be mitigated by caching on shared filesystems). +Alternatively, we may delegate file permission checking to gofers. In general, +file permission checks depend on the following properties of the accessor: + +- Filesystem UID/GID + +- Supplementary GIDs + +- Effective capabilities in the accessor's user namespace (i.e. the accessor's + effective capability set) + +- All UIDs and GIDs mapped in the accessor's user namespace (which determine + if the accessor's capabilities apply to accessed files) + +We may choose to delay implementation of file permission checking delegation, +although this is potentially costly since it doubles the number of required RPCs +for most operations on shared filesystems. We may also consider compromise +options, such as only delegating file permission checks for accessors in the +root user namespace. + +### Symbolic Links + +gVisor usually interprets symbolic link targets in its VFS rather than on the +filesystem containing the symbolic link; thus e.g. a symlink to +"/proc/self/maps" on a remote filesystem resolves to said file in the sentry's +procfs rather than the host's. This implies that: + +- Remote filesystem servers that proxy filesystems supporting symlinks must + check if each path component is a symlink during path traversal. + +- Absolute symlinks require that the sentry restart the operation at its + contextual VFS root (which is task-specific and may not be on a remote + filesystem at all), so if a remote filesystem server encounters an absolute + symlink during path traversal on behalf of a path-based operation, it must + terminate path traversal and return the symlink target. + +- Relative symlinks begin target resolution in the parent directory of the + symlink, so in theory most relative symlinks can be handled automatically + during the path traversal that encounters the symlink, provided that said + traversal is supplied with the number of remaining symlinks before `ELOOP`. + However, the new path traversed by the symlink target may cross VFS mount + boundaries, such that it's only safe for remote filesystem servers to + speculatively follow relative symlinks for side-effect-free operations such + as `stat` (where the sentry can simply ignore results that are inapplicable + due to crossing mount boundaries). We may choose to delay implementation of + this feature, at the cost of an additional RPC per relative symlink (note + that even if the symlink target crosses a mount boundary, the sentry will + need to `stat` the path to the mount boundary to confirm that each traversed + component is an accessible directory); until it is implemented, relative + symlinks may be handled like absolute symlinks, by terminating path + traversal and returning the symlink target. + +The possibility of symlinks (and the possibility of a compromised sentry) means +that the sentry may issue RPCs with paths that, in the absence of symlinks, +would traverse beyond the root of the remote filesystem. For example, the sentry +may issue an RPC with a path like "/foo/../..", on the premise that if "/foo" is +a symlink then the resulting path may be elsewhere on the remote filesystem. To +handle this, path traversal must also track its current depth below the remote +filesystem root, and terminate path traversal if it would ascend beyond this +point. + +### Path Traversal + +Since path-based VFS operations will translate to path-based RPCs, filesystem +servers will need to handle path traversal. From the perspective of a given +filesystem implementation in the server, there are two basic approaches to path +traversal: + +- Inode-walk: For each path component, obtain a handle to the underlying + filesystem object (e.g. with `open(O_PATH)`), check if that object is a + symlink (as described above) and that that object is accessible by the + caller (e.g. with `fstat()`), then continue to the next path component (e.g. + with `openat()`). This ensures that the checked filesystem object is the one + used to obtain the next object in the traversal, which is intuitively + appealing. However, while this approach works for host local filesystems, it + requires features that are not widely supported by other filesystems. + +- Path-walk: For each path component, use a path-based operation to determine + if the filesystem object currently referred to by that path component is a + symlink / is accessible. This is highly portable, but suffers from quadratic + behavior (at the level of the underlying filesystem implementation, the + first path component will be traversed a number of times equal to the number + of path components in the path). + +The implementation should support either option by delegating path traversal to +filesystem implementations within the server (like VFS and the remote filesystem +protocol itself), as inode-walking is still safe, efficient, amenable to FD +caching, and implementable on non-shared host local filesystems (a sufficiently +common case as to be worth considering in the design). + +Both approaches are susceptible to race conditions that may permit sandboxed +filesystem escapes: + +- Under inode-walk, a malicious application may cause a directory to be moved + (with `rename`) during path traversal, such that the filesystem + implementation incorrectly determines whether subsequent inodes are located + in paths that should be visible to sandboxed applications. + +- Under path-walk, a malicious application may cause a non-symlink file to be + replaced with a symlink during path traversal, such that following path + operations will incorrectly follow the symlink. + +Both race conditions can, to some extent, be mitigated in filesystem server +implementations by synchronizing path traversal with the hazardous operations in +question. However, shared filesystems are frequently used to share data between +sandboxed and unsandboxed applications in a controlled way, and in some cases a +malicious sandboxed application may be able to take advantage of a hazardous +filesystem operation performed by an unsandboxed application. In some cases, +filesystem features may be available to ensure safety even in such cases (e.g. +[the new openat2() syscall](https://man7.org/linux/man-pages/man2/openat2.2.html)), +but it is not clear how to solve this problem in general. (Note that this issue +is not specific to our design; rather, it is a fundamental limitation of +filesystem sandboxing.) + +### Filesystem Multiplexing + +A given sentry may need to access multiple distinct remote filesystems (e.g. +different volumes for a given container). In many cases, there is no advantage +to serving these filesystems from distinct filesystem servers, or accessing them +through distinct connections (factors such as maximum RPC concurrency should be +based on available host resources). Therefore, the protocol should support +multiplexing of distinct filesystem trees within a single session. 9P supports +this by allowing multiple calls to the `attach` RPC to produce fids representing +distinct filesystem trees, but this is somewhat clunky; we propose a much +simpler mechanism wherein each message that conveys a path also conveys a +numeric filesystem ID that identifies a filesystem tree. + +## Alternatives Considered + +### Additional Extensions to 9P + +There are at least three conceptual aspects to 9P: + +- Wire format: messages with a 4-byte little-endian size prefix, strings with + a 2-byte little-endian size prefix, etc. Whether the wire format is worth + retaining is unclear; in particular, it's unclear that the 9P wire format + has a significant advantage over protobufs, which are substantially easier + to extend. Note that the official Go protobuf implementation is widely known + to suffer from a significant number of performance deficiencies, so if we + choose to switch to protobuf, we may need to use an alternative toolchain + such as `gogo/protobuf` (which is also widely used in the Go ecosystem, e.g. + by Kubernetes). + +- Filesystem model: fids, qids, etc. Discarding this is one of the motivations + for this proposal. + +- RPCs: Twalk, Tlopen, etc. In addition to previously-described + inefficiencies, most of these are dependent on the filesystem model and + therefore must be discarded. + +### FUSE + +The FUSE (Filesystem in Userspace) protocol is frequently used to provide +arbitrary userspace filesystem implementations to a host Linux kernel. +Unfortunately, FUSE is also inode-based, and therefore doesn't address any of +the problems we have with 9P. + +### virtio-fs + +virtio-fs is an ongoing project aimed at improving Linux VM filesystem +performance when accessing Linux host filesystems (vs. virtio-9p). In brief, it +is based on: + +- Using a FUSE client in the guest that communicates over virtio with a FUSE + server in the host. + +- Using DAX to map the host page cache into the guest. + +- Using a file metadata table in shared memory to avoid VM exits for metadata + updates. + +None of these improvements seem applicable to gVisor: + +- As explained above, FUSE is still inode-based, so it is still susceptible to + most of the problems we have with 9P. + +- Our use of host file descriptors already allows us to leverage the host page + cache for file contents. + +- Our need for shared filesystem coherence is usually based on a user + requirement that an out-of-sandbox filesystem mutation is guaranteed to be + visible by all subsequent observations from within the sandbox, or vice + versa; it's not clear that this can be guaranteed without a synchronous + signaling mechanism like an RPC. diff --git a/tools/go_marshal/marshal/BUILD b/pkg/marshal/BUILD index bacfaa5a4..aac0161fa 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/pkg/marshal/BUILD @@ -6,11 +6,10 @@ go_library( name = "marshal", srcs = [ "marshal.go", + "marshal_impl_util.go", ], visibility = [ "//:sandbox", ], - deps = [ - "//pkg/usermem", - ], + deps = ["//pkg/usermem"], ) diff --git a/tools/go_marshal/marshal/marshal.go b/pkg/marshal/marshal.go index cb2166252..d8cb44b40 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -26,9 +26,10 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Task provides a subset of kernel.Task, used in marshalling. We don't import -// the kernel package directly to avoid circular dependency. -type Task interface { +// CopyContext defines the memory operations required to marshal to and from +// user memory. Typically, kernel.Task is used to provide implementations for +// these operations. +type CopyContext interface { // CopyScratchBuffer provides a task goroutine-local scratch buffer. See // kernel.CopyScratchBuffer. CopyScratchBuffer(size int) []byte @@ -58,18 +59,12 @@ type Marshallable interface { // likely make use of the type of these fields). SizeBytes() int - // MarshalBytes serializes a copy of a type to dst. dst may be smaller than - // SizeBytes(), which results in a part of the struct being marshalled. Note - // that this may have unexpected results for non-packed types, as implicit - // padding needs to be taken into account when reasoning about how much of - // the type is serialized. + // MarshalBytes serializes a copy of a type to dst. + // Precondition: dst must be at least SizeBytes() in length. MarshalBytes(dst []byte) - // UnmarshalBytes deserializes a type from src. src may be smaller than - // SizeBytes(), which results in a partially deserialized struct. Note that - // this may have unexpected results for non-packed types, as implicit - // padding needs to be taken into account when reasoning about how much of - // the type is deserialized. + // UnmarshalBytes deserializes a type from src. + // Precondition: src must be at least SizeBytes() in length. UnmarshalBytes(src []byte) // Packed returns true if the marshalled size of the type is the same as the @@ -89,8 +84,8 @@ type Marshallable interface { // representation to the dst buffer. This is only safe to do when the type // has no implicit padding, see Marshallable.Packed. When Packed would // return false, MarshalUnsafe should fall back to the safer but slower - // MarshalBytes. dst may be smaller than SizeBytes(), see comment for - // MarshalBytes for implications. + // MarshalBytes. + // Precondition: dst must be at least SizeBytes() in length. MarshalUnsafe(dst []byte) // UnmarshalUnsafe deserializes a type by directly copying to the underlying @@ -99,8 +94,8 @@ type Marshallable interface { // This allows much faster unmarshalling of types which have no implicit // padding, see Marshallable.Packed. When Packed would return false, // UnmarshalUnsafe should fall back to the safer but slower unmarshal - // mechanism implemented in UnmarshalBytes. src may be smaller than - // SizeBytes(), see comment for UnmarshalBytes for implications. + // mechanism implemented in UnmarshalBytes. + // Precondition: src must be at least SizeBytes() in length. UnmarshalUnsafe(src []byte) // CopyIn deserializes a Marshallable type from a task's memory. This may @@ -113,7 +108,7 @@ type Marshallable interface { // If the copy-in from the task memory is only partially successful, CopyIn // should still attempt to deserialize as much data as possible. See comment // for UnmarshalBytes. - CopyIn(task Task, addr usermem.Addr) (int, error) + CopyIn(cc CopyContext, addr usermem.Addr) (int, error) // CopyOut serializes a Marshallable type to a task's memory. This may only // be called from a task goroutine. This is more efficient than calling @@ -124,7 +119,7 @@ type Marshallable interface { // The copy-out to the task memory may be partially successful, in which // case CopyOut returns how much data was serialized. See comment for // MarshalBytes for implications. - CopyOut(task Task, addr usermem.Addr) (int, error) + CopyOut(cc CopyContext, addr usermem.Addr) (int, error) // CopyOutN is like CopyOut, but explicitly requests a partial // copy-out. Note that this may yield unexpected results for non-packed @@ -132,7 +127,7 @@ type Marshallable interface { // comment on MarshalBytes. // // The limit must be less than or equal to SizeBytes(). - CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) + CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) } // go-marshal generates additional functions for a type based on additional @@ -149,21 +144,23 @@ type Marshallable interface { // // Generates four additional functions for marshalling slices of Foos like this: // -// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It's -// // more efficient that repeatedly calling calling Foo.MarshalUnsafe over a -// // []Foo in a loop. +// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It +// // might be more efficient that repeatedly calling Foo.MarshalUnsafe +// // over a []Foo in a loop if the type is Packed. +// // Preconditions: dst must be at least len(src)*Foo.SizeBytes() in length. // func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... } // -// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It's -// // more efficient that repeatedly calling calling Foo.UnmarshalUnsafe over a -// // []Foo in a loop. +// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It +// // might be more efficient that repeatedly calling Foo.UnmarshalUnsafe +// // over a []Foo in a loop if the type is Packed. +// // Preconditions: src must be at least len(dst)*Foo.SizeBytes() in length. // func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } // // // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. -// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... } +// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... } // // // CopyFooSliceIn copies out a slice of Foo objects to the task's memory. -// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... } +// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... } // // The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where // %s is the first argument to the slice clause. This directive is not supported @@ -178,10 +175,10 @@ type Marshallable interface { // This is only valid on newtypes on primitives, and causes the generated // functions to accept slices of the inner type instead: // -// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... } +// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... } // // Without "inner", they would instead be: // -// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... } +// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... } // // This may help avoid a cast depending on how the generated functions are used. diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go new file mode 100644 index 000000000..ea75e09f2 --- /dev/null +++ b/pkg/marshal/marshal_impl_util.go @@ -0,0 +1,78 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package marshal + +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// StubMarshallable implements the Marshallable interface. +// StubMarshallable is a convenient embeddable type for satisfying the +// marshallable interface, but provides no actual implementation. It is +// useful when the marshallable interface needs to be implemented manually, +// but the caller doesn't require the full marshallable interface. +type StubMarshallable struct{} + +// WriteTo implements Marshallable.WriteTo. +func (StubMarshallable) WriteTo(w io.Writer) (n int64, err error) { + panic("Please implement your own WriteTo function") +} + +// SizeBytes implements Marshallable.SizeBytes. +func (StubMarshallable) SizeBytes() int { + panic("Please implement your own SizeBytes function") +} + +// MarshalBytes implements Marshallable.MarshalBytes. +func (StubMarshallable) MarshalBytes(dst []byte) { + panic("Please implement your own MarshalBytes function") +} + +// UnmarshalBytes implements Marshallable.UnmarshalBytes. +func (StubMarshallable) UnmarshalBytes(src []byte) { + panic("Please implement your own UnmarshalBytes function") +} + +// Packed implements Marshallable.Packed. +func (StubMarshallable) Packed() bool { + panic("Please implement your own Packed function") +} + +// MarshalUnsafe implements Marshallable.MarshalUnsafe. +func (StubMarshallable) MarshalUnsafe(dst []byte) { + panic("Please implement your own MarshalUnsafe function") +} + +// UnmarshalUnsafe implements Marshallable.UnmarshalUnsafe. +func (StubMarshallable) UnmarshalUnsafe(src []byte) { + panic("Please implement your own UnmarshalUnsafe function") +} + +// CopyIn implements Marshallable.CopyIn. +func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) { + panic("Please implement your own CopyIn function") +} + +// CopyOut implements Marshallable.CopyOut. +func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) { + panic("Please implement your own CopyOut function") +} + +// CopyOutN implements Marshallable.CopyOutN. +func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) { + panic("Please implement your own CopyOutN function") +} diff --git a/tools/go_marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD index cc08ba63a..d77a11c79 100644 --- a/tools/go_marshal/primitive/BUILD +++ b/pkg/marshal/primitive/BUILD @@ -12,7 +12,8 @@ go_library( "//:sandbox", ], deps = [ + "//pkg/context", + "//pkg/marshal", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go new file mode 100644 index 000000000..4b342de6b --- /dev/null +++ b/pkg/marshal/primitive/primitive.go @@ -0,0 +1,349 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package primitive defines marshal.Marshallable implementations for primitive +// types. +package primitive + +import ( + "io" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Int8 is a marshal.Marshallable implementation for int8. +// +// +marshal slice:Int8Slice:inner +type Int8 int8 + +// Uint8 is a marshal.Marshallable implementation for uint8. +// +// +marshal slice:Uint8Slice:inner +type Uint8 uint8 + +// Int16 is a marshal.Marshallable implementation for int16. +// +// +marshal slice:Int16Slice:inner +type Int16 int16 + +// Uint16 is a marshal.Marshallable implementation for uint16. +// +// +marshal slice:Uint16Slice:inner +type Uint16 uint16 + +// Int32 is a marshal.Marshallable implementation for int32. +// +// +marshal slice:Int32Slice:inner +type Int32 int32 + +// Uint32 is a marshal.Marshallable implementation for uint32. +// +// +marshal slice:Uint32Slice:inner +type Uint32 uint32 + +// Int64 is a marshal.Marshallable implementation for int64. +// +// +marshal slice:Int64Slice:inner +type Int64 int64 + +// Uint64 is a marshal.Marshallable implementation for uint64. +// +// +marshal slice:Uint64Slice:inner +type Uint64 uint64 + +// ByteSlice is a marshal.Marshallable implementation for []byte. +// This is a convenience wrapper around a dynamically sized type, and can't be +// embedded in other marshallable types because it breaks assumptions made by +// go-marshal internals. It violates the "no dynamically-sized types" +// constraint of the go-marshal library. +type ByteSlice []byte + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (b *ByteSlice) SizeBytes() int { + return len(*b) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (b *ByteSlice) MarshalBytes(dst []byte) { + copy(dst, *b) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (b *ByteSlice) UnmarshalBytes(src []byte) { + copy(*b, src) +} + +// Packed implements marshal.Marshallable.Packed. +func (b *ByteSlice) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *ByteSlice) MarshalUnsafe(dst []byte) { + b.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *ByteSlice) UnmarshalUnsafe(src []byte) { + b.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + return cc.CopyInBytes(addr, *b) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + return cc.CopyOutBytes(addr, *b) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + return cc.CopyOutBytes(addr, (*b)[:limit]) +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(*b) + return int64(n), err +} + +var _ marshal.Marshallable = (*ByteSlice)(nil) + +// Below, we define some convenience functions for marshalling primitive types +// using the newtypes above, without requiring superfluous casts. + +// 8-bit integers + +// CopyInt8In is a convenient wrapper for copying in an int8 from the task's +// memory. +func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) { + var buf Int8 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int8(buf) + return n, nil +} + +// CopyInt8Out is a convenient wrapper for copying out an int8 to the task's +// memory. +func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) { + srcP := Int8(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint8In is a convenient wrapper for copying in a uint8 from the task's +// memory. +func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) { + var buf Uint8 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint8(buf) + return n, nil +} + +// CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's +// memory. +func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) { + srcP := Uint8(src) + return srcP.CopyOut(cc, addr) +} + +// 16-bit integers + +// CopyInt16In is a convenient wrapper for copying in an int16 from the task's +// memory. +func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) { + var buf Int16 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int16(buf) + return n, nil +} + +// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's +// memory. +func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) { + srcP := Int16(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's +// memory. +func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) { + var buf Uint16 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint16(buf) + return n, nil +} + +// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's +// memory. +func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) { + srcP := Uint16(src) + return srcP.CopyOut(cc, addr) +} + +// 32-bit integers + +// CopyInt32In is a convenient wrapper for copying in an int32 from the task's +// memory. +func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) { + var buf Int32 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int32(buf) + return n, nil +} + +// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's +// memory. +func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) { + srcP := Int32(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's +// memory. +func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) { + var buf Uint32 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint32(buf) + return n, nil +} + +// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's +// memory. +func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) { + srcP := Uint32(src) + return srcP.CopyOut(cc, addr) +} + +// 64-bit integers + +// CopyInt64In is a convenient wrapper for copying in an int64 from the task's +// memory. +func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) { + var buf Int64 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int64(buf) + return n, nil +} + +// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's +// memory. +func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) { + srcP := Int64(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's +// memory. +func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) { + var buf Uint64 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint64(buf) + return n, nil +} + +// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's +// memory. +func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) { + srcP := Uint64(src) + return srcP.CopyOut(cc, addr) +} + +// CopyByteSliceIn is a convenient wrapper for copying in a []byte from the +// task's memory. +func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) { + var buf ByteSlice + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = []byte(buf) + return n, nil +} + +// CopyByteSliceOut is a convenient wrapper for copying out a []byte to the +// task's memory. +func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) { + srcP := ByteSlice(src) + return srcP.CopyOut(cc, addr) +} + +// CopyStringIn is a convenient wrapper for copying in a string from the +// task's memory. +func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) { + var buf ByteSlice + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = string(buf) + return n, nil +} + +// CopyStringOut is a convenient wrapper for copying out a string to the task's +// memory. +func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) { + srcP := ByteSlice(src) + return srcP.CopyOut(cc, addr) +} + +// IOCopyContext wraps an object implementing usermem.IO to implement +// marshal.CopyContext. +type IOCopyContext struct { + Ctx context.Context + IO usermem.IO + Opts usermem.IOOpts +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) +} diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index 5b0e4143a..a8fcb2e19 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "merkletree", srcs = ["merkletree.go"], + visibility = ["//pkg/sentry:internal"], deps = ["//pkg/usermem"], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 906f67943..4b4f9bd52 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -16,7 +16,9 @@ package merkletree import ( + "bytes" "crypto/sha256" + "fmt" "io" "gvisor.dev/gvisor/pkg/usermem" @@ -27,72 +29,153 @@ const ( sha256DigestSize = 32 ) -// Size defines the scale of a Merkle tree. -type Size struct { +// DigestSize returns the size (in bytes) of a digest. +// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). +func DigestSize() int { + return sha256DigestSize +} + +// Layout defines the scale of a Merkle tree. +type Layout struct { // blockSize is the size of a data block to be hashed. blockSize int64 // digestSize is the size of a generated hash. digestSize int64 - // hashesPerBlock is the number of hashes in a block. For example, if - // blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128 - // hashesPerBlock. Therefore 128 hashes in a lower level will be put into a - // block and generate a single hash in an upper level. - hashesPerBlock int64 - // levelStart is the start block index of each level. The number of levels in - // the tree is the length of the slice. The leafs (level 0) are hashes of - // blocks in the input data. The levels above are hashes of lower level - // hashes. The highest level is the root hash. - levelStart []int64 + // levelOffset contains the offset of the begnning of each level in + // bytes. The number of levels in the tree is the length of the slice. + // The leaf nodes (level 0) contain hashes of blocks of the input data. + // Each level N contains hashes of the blocks in level N-1. The highest + // level is the root hash. + levelOffset []int64 } -// MakeSize initializes and returns a new Size object describing the structure -// of a tree. dataSize specifies the number of the file system size in bytes. -func MakeSize(dataSize int64) Size { - size := Size{ +// InitLayout initializes and returns a new Layout object describing the structure +// of a tree. dataSize specifies the size of input data in bytes. +func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { + layout := Layout{ blockSize: usermem.PageSize, // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). - digestSize: sha256DigestSize, - hashesPerBlock: usermem.PageSize / sha256DigestSize, + digestSize: sha256DigestSize, + } + + // treeStart is the offset (in bytes) of the first level of the tree in + // the file. If data and tree are in different files, treeStart should + // be zero. If data is in the same file as the tree, treeStart points + // to the block after the last data block (which may be zero-padded). + var treeStart int64 + if dataAndTreeInSameFile { + treeStart = dataSize + if dataSize%layout.blockSize != 0 { + treeStart += layout.blockSize - dataSize%layout.blockSize + } } - numBlocks := (dataSize + size.blockSize - 1) / size.blockSize - level := int64(0) + + numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize + level := 0 offset := int64(0) - // Calcuate the number of levels in the Merkle tree and the beginning offset - // of each level. Level 0 is the level directly above the data blocks, while - // level NumLevels - 1 is the root. + // Calculate the number of levels in the Merkle tree and the beginning + // offset of each level. Level 0 consists of the leaf nodes that + // contain the hashes of the data blocks, while level numLevels - 1 is + // the root. for numBlocks > 1 { - size.levelStart = append(size.levelStart, offset) + layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) // Round numBlocks up to fill up a block. - numBlocks += (size.hashesPerBlock - numBlocks%size.hashesPerBlock) % size.hashesPerBlock - offset += numBlocks / size.hashesPerBlock - numBlocks = numBlocks / size.hashesPerBlock + numBlocks += (layout.hashesPerBlock() - numBlocks%layout.hashesPerBlock()) % layout.hashesPerBlock() + offset += numBlocks / layout.hashesPerBlock() + numBlocks = numBlocks / layout.hashesPerBlock() level++ } - size.levelStart = append(size.levelStart, offset) - return size + layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) + + return layout +} + +// hashesPerBlock() returns the number of digests in each block. For example, +// if blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128 +// hashesPerBlock. Therefore 128 hashes in one level will be combined in one +// hash in the level above. +func (layout Layout) hashesPerBlock() int64 { + return layout.blockSize / layout.digestSize +} + +// numLevels returns the total number of levels in the Merkle tree. +func (layout Layout) numLevels() int { + return len(layout.levelOffset) +} + +// rootLevel returns the level of the root hash. +func (layout Layout) rootLevel() int { + return layout.numLevels() - 1 +} + +// digestOffset finds the offset of a digest from the beginning of the tree. +// The target digest is at level of the tree, with index from the beginning of +// the current level. +func (layout Layout) digestOffset(level int, index int64) int64 { + return layout.levelOffset[level] + index*layout.digestSize +} + +// blockOffset finds the offset of a block from the beginning of the tree. The +// target block is at level of the tree, with index from the beginning of the +// current level. +func (layout Layout) blockOffset(level int, index int64) int64 { + return layout.levelOffset[level] + index*layout.blockSize } // Generate constructs a Merkle tree for the contents of data. The output is // written to treeWriter. The treeReader should be able to read the tree after // it has been written. That is, treeWriter and treeReader should point to the // same underlying data but have separate cursors. -func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) { - size := MakeSize(dataSize) +// Generate will modify the cursor for data, but always restores it to its +// original position upon exit. The cursor for tree is modified and not +// restored. +func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, treeWriter io.WriteSeeker, dataAndTreeInSameFile bool) ([]byte, error) { + layout := InitLayout(dataSize, dataAndTreeInSameFile) - numBlocks := (dataSize + size.blockSize - 1) / size.blockSize + numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize + + // If the data is in the same file as the tree, zero pad the last data + // block. + bytesInLastBlock := dataSize % layout.blockSize + if dataAndTreeInSameFile && bytesInLastBlock != 0 { + zeroBuf := make([]byte, layout.blockSize-bytesInLastBlock) + if _, err := treeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF { + return nil, err + } + if _, err := treeWriter.Write(zeroBuf); err != nil { + return nil, err + } + } + + // Store the current offset, so we can set it back once verification + // finishes. + origOffset, err := data.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + defer data.Seek(origOffset, io.SeekStart) + + // Read from the beginning of both data and treeReader. + if _, err := data.Seek(0, io.SeekStart); err != nil && err != io.EOF { + return nil, err + } + + if _, err := treeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF { + return nil, err + } var root []byte - for level := 0; level < len(size.levelStart); level++ { + for level := 0; level < layout.numLevels(); level++ { for i := int64(0); i < numBlocks; i++ { - buf := make([]byte, size.blockSize) + buf := make([]byte, layout.blockSize) var ( n int err error ) if level == 0 { - // Read data block from the target file since level 0 is directly above - // the raw data block. + // Read data block from the target file since level 0 includes hashes + // of blocks in the input data. n, err = data.Read(buf) } else { // Read data block from the tree file since levels higher than 0 are @@ -112,7 +195,7 @@ func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter i // Hash the bytes in buf. digest := sha256.Sum256(buf) - if level == len(size.levelStart)-1 { + if level == layout.rootLevel() { root = digest[:] } @@ -121,15 +204,169 @@ func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter i return nil, err } } - // If the genereated digests do not round up to a block, zero-padding the + // If the generated digests do not round up to a block, zero-padding the // remaining of the last block. But no need to do so for root. - if level != len(size.levelStart)-1 && numBlocks%size.hashesPerBlock != 0 { - zeroBuf := make([]byte, size.blockSize-(numBlocks%size.hashesPerBlock)*size.digestSize) + if level != layout.rootLevel() && numBlocks%layout.hashesPerBlock() != 0 { + zeroBuf := make([]byte, layout.blockSize-(numBlocks%layout.hashesPerBlock())*layout.digestSize) if _, err := treeWriter.Write(zeroBuf[:]); err != nil { return nil, err } } - numBlocks = (numBlocks + size.hashesPerBlock - 1) / size.hashesPerBlock + numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock() } return root, nil } + +// Verify verifies the content read from data with offset. The content is +// verified against tree. If content spans across multiple blocks, each block is +// verified. Verification fails if the hash of the data does not match the tree +// at any level, or if the final root hash does not match expectedRoot. +// Once the data is verified, it will be written using w. +// Verify will modify the cursor for data, but always restores it to its +// original position upon exit. The cursor for tree is modified and not +// restored. +func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) { + if readSize <= 0 { + return 0, fmt.Errorf("Unexpected read size: %d", readSize) + } + layout := InitLayout(int64(dataSize), dataAndTreeInSameFile) + + // Calculate the index of blocks that includes the target range in input + // data. + firstDataBlock := readOffset / layout.blockSize + lastDataBlock := (readOffset + readSize - 1) / layout.blockSize + + // Store the current offset, so we can set it back once verification + // finishes. + origOffset, err := data.Seek(0, io.SeekCurrent) + if err != nil { + return 0, fmt.Errorf("Find current data offset failed: %v", err) + } + defer data.Seek(origOffset, io.SeekStart) + + // Move to the first block that contains target data. + if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil { + return 0, fmt.Errorf("Seek to datablock start failed: %v", err) + } + + buf := make([]byte, layout.blockSize) + var readErr error + total := int64(0) + for i := firstDataBlock; i <= lastDataBlock; i++ { + // Read a block that includes all or part of target range in + // input data. + bytesRead, err := data.Read(buf) + readErr = err + // If at the end of input data and all previous blocks are + // verified, return the verified input data and EOF. + if readErr == io.EOF && bytesRead == 0 { + break + } + if readErr != nil && readErr != io.EOF { + return 0, fmt.Errorf("Read from data failed: %v", err) + } + // If this is the end of file, zero the remaining bytes in buf, + // otherwise they are still from the previous block. + // TODO(b/162908070): Investigate possible issues with zero + // padding the data. + if bytesRead < len(buf) { + for j := bytesRead; j < len(buf); j++ { + buf[j] = 0 + } + } + if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil { + return 0, err + } + // startOff is the beginning of the read range within the + // current data block. Note that for all blocks other than the + // first, startOff should be 0. + startOff := int64(0) + if i == firstDataBlock { + startOff = readOffset % layout.blockSize + } + // endOff is the end of the read range within the current data + // block. Note that for all blocks other than the last, endOff + // should be the block size. + endOff := layout.blockSize + if i == lastDataBlock { + endOff = (readOffset+readSize-1)%layout.blockSize + 1 + } + // If the provided size exceeds the end of input data, we should + // only copy the parts in buf that's part of input data. + if startOff > int64(bytesRead) { + startOff = int64(bytesRead) + } + if endOff > int64(bytesRead) { + endOff = int64(bytesRead) + } + n, err := w.Write(buf[startOff:endOff]) + if err != nil { + return total, err + } + total += int64(n) + + } + return total, readErr +} + +// verifyBlock verifies a block against tree. index is the number of block in +// original data. The block is verified through each level of the tree. It +// fails if the calculated hash from block is different from any level of +// hashes stored in tree. And the final root hash is compared with +// expectedRoot. verifyBlock modifies the cursor for tree. Users needs to +// maintain the cursor if intended. +func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error { + if len(dataBlock) != int(layout.blockSize) { + return fmt.Errorf("incorrect block size") + } + + expectedDigest := make([]byte, layout.digestSize) + treeBlock := make([]byte, layout.blockSize) + var digest []byte + for level := 0; level < layout.numLevels(); level++ { + // Calculate hash. + if level == 0 { + digestArray := sha256.Sum256(dataBlock) + digest = digestArray[:] + } else { + // Read a block in previous level that contains the + // hash we just generated, and generate a next level + // hash from it. + if _, err := tree.Seek(layout.blockOffset(level-1, blockIndex), io.SeekStart); err != nil { + return err + } + if _, err := tree.Read(treeBlock); err != nil { + return err + } + digestArray := sha256.Sum256(treeBlock) + digest = digestArray[:] + } + + // Move to stored hash for the current block, read the digest + // and store in expectedDigest. + if _, err := tree.Seek(layout.digestOffset(level, blockIndex), io.SeekStart); err != nil { + return err + } + if _, err := tree.Read(expectedDigest); err != nil { + return err + } + + if !bytes.Equal(digest, expectedDigest) { + return fmt.Errorf("Verification failed") + } + + // If this is the root layer, no need to generate next level + // hash. + if level == layout.rootLevel() { + break + } + blockIndex = blockIndex / layout.hashesPerBlock() + } + + // Verification for the tree succeeded. Now compare the root hash in the + // tree with expectedRoot. + if !bytes.Equal(digest[:], expectedRoot) { + return fmt.Errorf("Verification failed") + } + return nil +} diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index 7344db0b6..daaca759a 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -17,106 +17,402 @@ package merkletree import ( "bytes" "fmt" + "io" + "math/rand" "testing" + "time" "gvisor.dev/gvisor/pkg/usermem" ) -func TestSize(t *testing.T) { +func TestLayout(t *testing.T) { testCases := []struct { - dataSize int64 - expectedLevelStart []int64 + dataSize int64 + dataAndTreeInSameFile bool + expectedLevelOffset []int64 }{ { - dataSize: 100, - expectedLevelStart: []int64{0}, + dataSize: 100, + dataAndTreeInSameFile: false, + expectedLevelOffset: []int64{0}, }, { - dataSize: 1000000, - expectedLevelStart: []int64{0, 2, 3}, + dataSize: 100, + dataAndTreeInSameFile: true, + expectedLevelOffset: []int64{usermem.PageSize}, }, { - dataSize: 4096 * int64(usermem.PageSize), - expectedLevelStart: []int64{0, 32, 33}, + dataSize: 1000000, + dataAndTreeInSameFile: false, + expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, + }, + { + dataSize: 1000000, + dataAndTreeInSameFile: true, + expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, + }, + { + dataSize: 4096 * int64(usermem.PageSize), + dataAndTreeInSameFile: false, + expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, + }, + { + dataSize: 4096 * int64(usermem.PageSize), + dataAndTreeInSameFile: true, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, } for _, tc := range testCases { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - s := MakeSize(tc.dataSize) - if s.blockSize != int64(usermem.PageSize) { - t.Errorf("got blockSize %d, want %d", s.blockSize, usermem.PageSize) + l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + if l.blockSize != int64(usermem.PageSize) { + t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - if s.digestSize != sha256DigestSize { - t.Errorf("got digestSize %d, want %d", s.digestSize, sha256DigestSize) + if l.digestSize != sha256DigestSize { + t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } - if len(s.levelStart) != len(tc.expectedLevelStart) { - t.Errorf("got levels %d, want %d", len(s.levelStart), len(tc.expectedLevelStart)) + if l.numLevels() != len(tc.expectedLevelOffset) { + t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) } - for i := 0; i < len(s.levelStart) && i < len(tc.expectedLevelStart); i++ { - if s.levelStart[i] != tc.expectedLevelStart[i] { - t.Errorf("got levelStart[%d] %d, want %d", i, s.levelStart[i], tc.expectedLevelStart[i]) + for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ { + if l.levelOffset[i] != tc.expectedLevelOffset[i] { + t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) } } }) } } +// bytesReadWriter is used to read from/write to/seek in a byte array. Unlike +// bytes.Buffer, it keeps the whole buffer during read so that it can be reused. +type bytesReadWriter struct { + // bytes contains the underlying byte array. + bytes []byte + // readPos is the currently location for Read. Write always appends to + // the end of the array. + readPos int +} + +func (brw *bytesReadWriter) Write(p []byte) (int, error) { + brw.bytes = append(brw.bytes, p...) + return len(p), nil +} + +func (brw *bytesReadWriter) Read(p []byte) (int, error) { + if brw.readPos >= len(brw.bytes) { + return 0, io.EOF + } + bytesRead := copy(p, brw.bytes[brw.readPos:]) + brw.readPos += bytesRead + if bytesRead < len(p) { + return bytesRead, io.EOF + } + return bytesRead, nil +} + +func (brw *bytesReadWriter) Seek(offset int64, whence int) (int64, error) { + off := offset + if whence == io.SeekCurrent { + off += int64(brw.readPos) + } + if whence == io.SeekEnd { + off += int64(len(brw.bytes)) + } + if off < 0 { + panic("seek with negative offset") + } + if off >= int64(len(brw.bytes)) { + return 0, io.EOF + } + brw.readPos = int(off) + return off, nil +} + func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - dataSize int - startWith []byte + data []byte expectedRoot []byte }{ { - dataSize: usermem.PageSize, - startWith: nil, + data: bytes.Repeat([]byte{0}, usermem.PageSize), expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167}, }, { - dataSize: 128*usermem.PageSize + 1, - startWith: nil, + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132}, }, { - dataSize: 1, - startWith: []byte{'a'}, + data: []byte{'a'}, expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210}, }, { - dataSize: 1, - startWith: []byte{'1'}, - expectedRoot: []byte{74, 35, 103, 179, 176, 149, 254, 112, 42, 65, 104, 66, 119, 56, 133, 124, 228, 15, 65, 161, 150, 0, 117, 174, 242, 34, 115, 115, 218, 37, 3, 105}, + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90}, }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - var ( - data bytes.Buffer - tree bytes.Buffer - ) - - startSize := len(tc.startWith) - _, err := data.Write(tc.startWith) - if err != nil { - t.Fatalf("Failed to write to data: %v", err) - } - _, err = data.Write(make([]byte, tc.dataSize-startSize)) - if err != nil { - t.Fatalf("Failed to write to data: %v", err) - } + t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + var root []byte + var err error + if dataAndTreeInSameFile { + tree.Write(tc.data) + root, err = Generate(&tree, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) + } else { + root, err = Generate(&bytesReadWriter{ + bytes: tc.data, + }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) + } + if err != nil { + t.Fatalf("Got err: %v, want nil", err) + } - root, err := Generate(&data, int64(tc.dataSize), &tree, &tree) - if err != nil { - t.Fatalf("Generate failed: %v", err) + if !bytes.Equal(root, tc.expectedRoot) { + t.Errorf("Got root: %v, want %v", root, tc.expectedRoot) + } } + }) + } +} - if !bytes.Equal(root, tc.expectedRoot) { - t.Errorf("Unexpected root") +func TestVerify(t *testing.T) { + // The input data has size dataSize. The portion to be verified ranges from + // verifyStart with verifySize. A bit is flipped in outOfRangeByteIndex to + // confirm that modifications outside the verification range does not cause + // issue. And a bit is flipped in modifyByte to confirm that + // modifications in the verification range is caught during verification. + testCases := []struct { + dataSize int64 + verifyStart int64 + verifySize int64 + // A byte in input data is modified during the test. If the + // modified byte falls in verification range, Verify should + // fail, otherwise Verify should still succeed. + modifyByte int64 + shouldSucceed bool + }{ + // Verify range start outside the data range should fail. + { + dataSize: usermem.PageSize, + verifyStart: usermem.PageSize, + verifySize: 1, + modifyByte: 0, + shouldSucceed: false, + }, + // Verifying range is valid if it starts inside data and ends + // outside data range, in that case start to the end of data is + // verified. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 2 * usermem.PageSize, + modifyByte: 0, + shouldSucceed: false, + }, + // Invalid verify range (negative size) should fail. + { + dataSize: usermem.PageSize, + verifyStart: 1, + verifySize: -1, + modifyByte: 0, + shouldSucceed: false, + }, + // Invalid verify range (0 size) should fail. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + shouldSucceed: false, + }, + // The test cases below use a block-aligned verify range. + // Modifying a byte in the verified range should cause verify + // to fail. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + shouldSucceed: false, + }, + // Modifying a byte before the verified range should not cause + // verify to fail. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4*usermem.PageSize - 1, + shouldSucceed: true, + }, + // Modifying a byte after the verified range should not cause + // verify to fail. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 5 * usermem.PageSize, + shouldSucceed: true, + }, + // The tests below use a non-block-aligned verify range. + // Modifying a byte at strat of verify range should cause + // verify to fail. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + shouldSucceed: false, + }, + // Modifying a byte at the end of verify range should cause + // verify to fail. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + shouldSucceed: false, + }, + // Modifying a byte in the middle verified block should cause + // verify to fail. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + shouldSucceed: false, + }, + // Modifying a byte in the first block in the verified range + // should cause verify to fail, even the modified bit itself is + // out of verify range. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + shouldSucceed: false, + }, + // Modifying a byte in the last block in the verified range + // should cause verify to fail, even the modified bit itself is + // out of verify range. + { + dataSize: 8 * usermem.PageSize, + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + shouldSucceed: false, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%d", tc.modifyByte), func(t *testing.T) { + data := make([]byte, tc.dataSize) + // Generate random bytes in data. + rand.Read(data) + + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + var root []byte + var err error + if dataAndTreeInSameFile { + tree.Write(data) + root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile) + } else { + root, err = Generate(&bytesReadWriter{ + bytes: data, + }, int64(tc.dataSize), &tree, &tree, false /* dataAndTreeInSameFile */) + } + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + + // Flip a bit in data and checks Verify results. + var buf bytes.Buffer + data[tc.modifyByte] ^= 1 + if tc.shouldSucceed { + n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + } else { + if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { + t.Errorf("Verification succeeded when expected to fail") + } + } } }) } } + +func TestVerifyRandom(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + // Use a random dataSize. Minimum size 2 so that we can pick a random + // portion from it. + dataSize := rand.Int63n(200*usermem.PageSize) + 2 + data := make([]byte, dataSize) + // Generate random bytes in data. + rand.Read(data) + + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + var root []byte + var err error + if dataAndTreeInSameFile { + tree.Write(data) + root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile) + } else { + root, err = Generate(&bytesReadWriter{ + bytes: data, + }, int64(dataSize), &tree, &tree, dataAndTreeInSameFile) + } + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 + + var buf bytes.Buffer + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + + buf.Reset() + // Flip a random bit in randPortion, and check that verification fails. + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 + + if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { + t.Errorf("Verification succeeded for modified data") + } + } +} diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index 64aa365ce..d012c5734 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -106,8 +106,8 @@ type customUint64Metric struct { // after Initialized. // // Preconditions: -// * name must be globally unique. -// * Initialize/Disable have not been called. +// * name must be globally unique. +// * Initialize/Disable have not been called. func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error { if initialized { return ErrInitializationDone @@ -221,7 +221,7 @@ var ( // EmitMetricUpdate is thread-safe. // // Preconditions: -// * Initialize has been called. +// * Initialize has been called. func EmitMetricUpdate() { emitMu.Lock() defer emitMu.Unlock() diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 2ee07b664..28fe081d6 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -54,6 +54,8 @@ func (c *Client) newFile(fid FID) *clientFile { // // This proxies all of the interfaces found in file.go. type clientFile struct { + DisallowServerCalls + // client is the originating client. client *Client @@ -283,6 +285,39 @@ func (c *clientFile) Close() error { return nil } +// SetAttrClose implements File.SetAttrClose. +func (c *clientFile) SetAttrClose(valid SetAttrMask, attr SetAttr) error { + if !versionSupportsTsetattrclunk(c.client.version) { + setAttrErr := c.SetAttr(valid, attr) + + // Try to close file even in case of failure above. Since the state of the + // file is unknown to the caller, it will not attempt to close the file + // again. + if err := c.Close(); err != nil { + return err + } + + return setAttrErr + } + + // Avoid double close. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return syscall.EBADF + } + + // Send the message. + if err := c.client.sendRecv(&Tsetattrclunk{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattrclunk{}); err != nil { + // If an error occurred, we toss away the FID. This isn't ideal, + // but I'm not sure what else makes sense in this context. + log.Warningf("Tsetattrclunk failed, losing FID %v: %v", c.fid, err) + return err + } + + // Return the FID to the pool. + c.client.fidPool.Put(uint64(c.fid)) + return nil +} + // Open implements File.Open. func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) { if atomic.LoadUint32(&c.closed) != 0 { @@ -681,6 +716,3 @@ func (c *clientFile) Flush() error { return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{}) } - -// Renamed implements File.Renamed. -func (c *clientFile) Renamed(newDir File, newName string) {} diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index c757583e0..b78fdab7a 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -62,6 +62,8 @@ func TestVersion(t *testing.T) { } func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + b.ReportAllocs() + // See above. serverSocket, clientSocket, err := unet.SocketPair(false) if err != nil { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index cab35896f..c2e3a3f98 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -135,6 +135,14 @@ type File interface { // On the server, Close has no concurrency guarantee. Close() error + // SetAttrClose is the equivalent of calling SetAttr() followed by Close(). + // This can be used to set file times before closing the file in a single + // operation. + // + // On the server, SetAttr has a write concurrency guarantee. + // On the server, Close has no concurrency guarantee. + SetAttrClose(valid SetAttrMask, attr SetAttr) error + // Open must be called prior to using Read, Write or Readdir. Once Open // is called, some operations, such as Walk, will no longer work. // @@ -286,3 +294,19 @@ type DefaultWalkGetAttr struct{} func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS } + +// DisallowClientCalls panics if a client-only function is called. +type DisallowClientCalls struct{} + +// SetAttrClose implements File.SetAttrClose. +func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { + panic("SetAttrClose should not be called on the server") +} + +// DisallowServerCalls panics if a server-only function is called. +type DisallowServerCalls struct{} + +// Renamed implements File.Renamed. +func (*clientFile) Renamed(File, string) { + panic("Renamed should not be called on the client") +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 1db5797dd..abd237f46 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -123,6 +123,37 @@ func (t *Tclunk) handle(cs *connState) message { return &Rclunk{} } +func (t *Tsetattrclunk) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + setAttrErr := ref.safelyWrite(func() error { + // We don't allow setattr on files that have been deleted. + // This might be technically incorrect, as it's possible that + // there were multiple links and you can still change the + // corresponding inode information. + if ref.isDeleted() { + return syscall.EINVAL + } + + // Set the attributes. + return ref.file.SetAttr(t.Valid, t.SetAttr) + }) + + // Try to delete FID even in case of failure above. Since the state of the + // file is unknown to the caller, it will not attempt to close the file again. + if !cs.DeleteFID(t.FID) { + return newErr(syscall.EBADF) + } + if setAttrErr != nil { + return newErr(setAttrErr) + } + return &Rsetattrclunk{} +} + // handle implements handler.handle. func (t *Tremove) handle(cs *connState) message { ref, ok := cs.LookupFID(t.FID) diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 57b89ad7d..cf13cbb69 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -317,6 +317,64 @@ func (r *Rclunk) String() string { return "Rclunk{}" } +// Tsetattrclunk is a setattr+close request. +type Tsetattrclunk struct { + // FID is the FID to change. + FID FID + + // Valid is the set of bits which will be used. + Valid SetAttrMask + + // SetAttr is the set request. + SetAttr SetAttr +} + +// decode implements encoder.decode. +func (t *Tsetattrclunk) decode(b *buffer) { + t.FID = b.ReadFID() + t.Valid.decode(b) + t.SetAttr.decode(b) +} + +// encode implements encoder.encode. +func (t *Tsetattrclunk) encode(b *buffer) { + b.WriteFID(t.FID) + t.Valid.encode(b) + t.SetAttr.encode(b) +} + +// Type implements message.Type. +func (*Tsetattrclunk) Type() MsgType { + return MsgTsetattrclunk +} + +// String implements fmt.Stringer. +func (t *Tsetattrclunk) String() string { + return fmt.Sprintf("Tsetattrclunk{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr) +} + +// Rsetattrclunk is a setattr+close response. +type Rsetattrclunk struct { +} + +// decode implements encoder.decode. +func (*Rsetattrclunk) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rsetattrclunk) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rsetattrclunk) Type() MsgType { + return MsgRsetattrclunk +} + +// String implements fmt.Stringer. +func (r *Rsetattrclunk) String() string { + return "Rsetattrclunk{}" +} + // Tremove is a remove request. // // This will eventually be replaced by Tunlinkat. @@ -2506,7 +2564,7 @@ type msgFactory struct { var msgRegistry registry type registry struct { - factories [math.MaxUint8]msgFactory + factories [math.MaxUint8 + 1]msgFactory // largestFixedSize is computed so that given some message size M, you can // compute the maximum payload size (e.g. for Twrite, Rread) with @@ -2657,6 +2715,8 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) + msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go index 7facc9f5e..bfeb6c236 100644 --- a/pkg/p9/messages_test.go +++ b/pkg/p9/messages_test.go @@ -376,6 +376,30 @@ func TestEncodeDecode(t *testing.T) { &Rumknod{ Rmknod{QID: QID{Type: 1}}, }, + &Tsetattrclunk{ + FID: 1, + Valid: SetAttrMask{ + Permissions: true, + UID: true, + GID: true, + Size: true, + ATime: true, + MTime: true, + CTime: true, + ATimeNotSystemTime: true, + MTimeNotSystemTime: true, + }, + SetAttr: SetAttr{ + Permissions: 1, + UID: 2, + GID: 3, + Size: 4, + ATimeSeconds: 5, + ATimeNanoSeconds: 6, + MTimeSeconds: 7, + MTimeNanoSeconds: 8, + }, + }, } for _, enc := range objs { diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 122c457d2..2235f8968 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -315,86 +315,88 @@ type MsgType uint8 // MsgType declarations. const ( - MsgTlerror MsgType = 6 - MsgRlerror = 7 - MsgTstatfs = 8 - MsgRstatfs = 9 - MsgTlopen = 12 - MsgRlopen = 13 - MsgTlcreate = 14 - MsgRlcreate = 15 - MsgTsymlink = 16 - MsgRsymlink = 17 - MsgTmknod = 18 - MsgRmknod = 19 - MsgTrename = 20 - MsgRrename = 21 - MsgTreadlink = 22 - MsgRreadlink = 23 - MsgTgetattr = 24 - MsgRgetattr = 25 - MsgTsetattr = 26 - MsgRsetattr = 27 - MsgTlistxattr = 28 - MsgRlistxattr = 29 - MsgTxattrwalk = 30 - MsgRxattrwalk = 31 - MsgTxattrcreate = 32 - MsgRxattrcreate = 33 - MsgTgetxattr = 34 - MsgRgetxattr = 35 - MsgTsetxattr = 36 - MsgRsetxattr = 37 - MsgTremovexattr = 38 - MsgRremovexattr = 39 - MsgTreaddir = 40 - MsgRreaddir = 41 - MsgTfsync = 50 - MsgRfsync = 51 - MsgTlink = 70 - MsgRlink = 71 - MsgTmkdir = 72 - MsgRmkdir = 73 - MsgTrenameat = 74 - MsgRrenameat = 75 - MsgTunlinkat = 76 - MsgRunlinkat = 77 - MsgTversion = 100 - MsgRversion = 101 - MsgTauth = 102 - MsgRauth = 103 - MsgTattach = 104 - MsgRattach = 105 - MsgTflush = 108 - MsgRflush = 109 - MsgTwalk = 110 - MsgRwalk = 111 - MsgTread = 116 - MsgRread = 117 - MsgTwrite = 118 - MsgRwrite = 119 - MsgTclunk = 120 - MsgRclunk = 121 - MsgTremove = 122 - MsgRremove = 123 - MsgTflushf = 124 - MsgRflushf = 125 - MsgTwalkgetattr = 126 - MsgRwalkgetattr = 127 - MsgTucreate = 128 - MsgRucreate = 129 - MsgTumkdir = 130 - MsgRumkdir = 131 - MsgTumknod = 132 - MsgRumknod = 133 - MsgTusymlink = 134 - MsgRusymlink = 135 - MsgTlconnect = 136 - MsgRlconnect = 137 - MsgTallocate = 138 - MsgRallocate = 139 - MsgTchannel = 250 - MsgRchannel = 251 + MsgTlerror MsgType = 6 + MsgRlerror MsgType = 7 + MsgTstatfs MsgType = 8 + MsgRstatfs MsgType = 9 + MsgTlopen MsgType = 12 + MsgRlopen MsgType = 13 + MsgTlcreate MsgType = 14 + MsgRlcreate MsgType = 15 + MsgTsymlink MsgType = 16 + MsgRsymlink MsgType = 17 + MsgTmknod MsgType = 18 + MsgRmknod MsgType = 19 + MsgTrename MsgType = 20 + MsgRrename MsgType = 21 + MsgTreadlink MsgType = 22 + MsgRreadlink MsgType = 23 + MsgTgetattr MsgType = 24 + MsgRgetattr MsgType = 25 + MsgTsetattr MsgType = 26 + MsgRsetattr MsgType = 27 + MsgTlistxattr MsgType = 28 + MsgRlistxattr MsgType = 29 + MsgTxattrwalk MsgType = 30 + MsgRxattrwalk MsgType = 31 + MsgTxattrcreate MsgType = 32 + MsgRxattrcreate MsgType = 33 + MsgTgetxattr MsgType = 34 + MsgRgetxattr MsgType = 35 + MsgTsetxattr MsgType = 36 + MsgRsetxattr MsgType = 37 + MsgTremovexattr MsgType = 38 + MsgRremovexattr MsgType = 39 + MsgTreaddir MsgType = 40 + MsgRreaddir MsgType = 41 + MsgTfsync MsgType = 50 + MsgRfsync MsgType = 51 + MsgTlink MsgType = 70 + MsgRlink MsgType = 71 + MsgTmkdir MsgType = 72 + MsgRmkdir MsgType = 73 + MsgTrenameat MsgType = 74 + MsgRrenameat MsgType = 75 + MsgTunlinkat MsgType = 76 + MsgRunlinkat MsgType = 77 + MsgTversion MsgType = 100 + MsgRversion MsgType = 101 + MsgTauth MsgType = 102 + MsgRauth MsgType = 103 + MsgTattach MsgType = 104 + MsgRattach MsgType = 105 + MsgTflush MsgType = 108 + MsgRflush MsgType = 109 + MsgTwalk MsgType = 110 + MsgRwalk MsgType = 111 + MsgTread MsgType = 116 + MsgRread MsgType = 117 + MsgTwrite MsgType = 118 + MsgRwrite MsgType = 119 + MsgTclunk MsgType = 120 + MsgRclunk MsgType = 121 + MsgTremove MsgType = 122 + MsgRremove MsgType = 123 + MsgTflushf MsgType = 124 + MsgRflushf MsgType = 125 + MsgTwalkgetattr MsgType = 126 + MsgRwalkgetattr MsgType = 127 + MsgTucreate MsgType = 128 + MsgRucreate MsgType = 129 + MsgTumkdir MsgType = 130 + MsgRumkdir MsgType = 131 + MsgTumknod MsgType = 132 + MsgRumknod MsgType = 133 + MsgTusymlink MsgType = 134 + MsgRusymlink MsgType = 135 + MsgTlconnect MsgType = 136 + MsgRlconnect MsgType = 137 + MsgTallocate MsgType = 138 + MsgRallocate MsgType = 139 + MsgTsetattrclunk MsgType = 140 + MsgRsetattrclunk MsgType = 141 + MsgTchannel MsgType = 250 + MsgRchannel MsgType = 251 ) // QIDType represents the file type for QIDs. diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e7bb3db2..6e605b14c 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -1225,22 +1225,31 @@ func TestOpen(t *testing.T) { func TestClose(t *testing.T) { type closeTest struct { name string - closeFn func(backend *Mock, f p9.File) + closeFn func(backend *Mock, f p9.File) error } cases := []closeTest{ { name: "close", - closeFn: func(_ *Mock, f p9.File) { - f.Close() + closeFn: func(_ *Mock, f p9.File) error { + return f.Close() }, }, { name: "remove", - closeFn: func(backend *Mock, f p9.File) { + closeFn: func(backend *Mock, f p9.File) error { // Allow the rename call in the parent, automatically translated. backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) - f.(deprecatedRemover).Remove() + return f.(deprecatedRemover).Remove() + }, + }, + { + name: "setAttrClose", + closeFn: func(backend *Mock, f p9.File) error { + valid := p9.SetAttrMask{ATime: true} + attr := p9.SetAttr{ATimeSeconds: 1, ATimeNanoSeconds: 2} + backend.EXPECT().SetAttr(valid, attr).Times(1) + return f.SetAttrClose(valid, attr) }, }, } @@ -1258,7 +1267,9 @@ func TestClose(t *testing.T) { _, backend, f := walkHelper(h, name, root) // Close via the prescribed method. - tc.closeFn(backend, f) + if err := tc.closeFn(backend, f); err != nil { + t.Fatalf("closeFn failed: %v", err) + } // Everything should fail with EBADF. if _, _, err := f.Walk(nil); err != syscall.EBADF { diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 60cf94fa1..3736f12a3 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -60,12 +60,6 @@ type connState struct { // server is the backing server. server *Server - // sendMu is the send lock. - sendMu sync.Mutex - - // conn is the connection. - conn *unet.Socket - // fids is the set of active FIDs. // // This is used to find FIDs for files. @@ -87,16 +81,30 @@ type connState struct { // version 0 implies 9P2000.L. version uint32 + // pendingWg counts requests that are still being handled. + pendingWg sync.WaitGroup + // -- below relates to the legacy handler -- - // recvOkay indicates that a receive may start. - recvOkay chan bool + // recvMu serializes receiving from conn. + recvMu sync.Mutex + + // recvIdle is the number of goroutines in handleRequests() attempting to + // lock recvMu so that they can receive from conn. recvIdle is accessed + // using atomic memory operations. + recvIdle int32 - // recvDone is signalled when a message is received. - recvDone chan error + // If recvShutdown is true, at least one goroutine has observed a + // connection error while receiving from conn, and all goroutines in + // handleRequests() should exit immediately. recvShutdown is protected by + // recvMu. + recvShutdown bool - // sendDone is signalled when a send is finished. - sendDone chan error + // sendMu serializes sending to conn. + sendMu sync.Mutex + + // conn is the connection used by the legacy transport. + conn *unet.Socket // -- below relates to the flipcall handler -- @@ -479,7 +487,9 @@ func (cs *connState) lookupChannel(id uint32) *channel { // handle handles a single message. func (cs *connState) handle(m message) (r message) { + cs.pendingWg.Add(1) defer func() { + cs.pendingWg.Done() if r == nil { // Don't allow a panic to propagate. err := recover() @@ -503,11 +513,21 @@ func (cs *connState) handle(m message) (r message) { return } -// handleRequest handles a single request. -// -// The recvDone channel is signaled when recv is done (with a error if -// necessary). The sendDone channel is signaled with the result of the send. -func (cs *connState) handleRequest() { +// handleRequest handles a single request. It returns true if the caller should +// continue handling requests and false if it should terminate. +func (cs *connState) handleRequest() bool { + // Obtain the right to receive a message from cs.conn. + atomic.AddInt32(&cs.recvIdle, 1) + cs.recvMu.Lock() + atomic.AddInt32(&cs.recvIdle, -1) + + if cs.recvShutdown { + // Another goroutine already detected a connection problem; exit + // immediately. + cs.recvMu.Unlock() + return false + } + messageSize := atomic.LoadUint32(&cs.messageSize) if messageSize == 0 { // Default or not yet negotiated. @@ -518,12 +538,17 @@ func (cs *connState) handleRequest() { tag, m, err := recv(cs.conn, messageSize, msgRegistry.get) if errSocket, ok := err.(ErrSocket); ok { // Connection problem; stop serving. - cs.recvDone <- errSocket.error - return + log.Debugf("p9.recv: %v", errSocket.error) + cs.recvShutdown = true + cs.recvMu.Unlock() + return false } - // Signal receive is done. - cs.recvDone <- nil + // Ensure that another goroutine is available to receive from cs.conn. + if atomic.LoadInt32(&cs.recvIdle) == 0 { + go cs.handleRequests() // S/R-SAFE: Irrelevant. + } + cs.recvMu.Unlock() // Deal with other errors. if err != nil && err != io.EOF { @@ -532,16 +557,17 @@ func (cs *connState) handleRequest() { cs.sendMu.Lock() err := send(cs.conn, tag, newErr(err)) cs.sendMu.Unlock() - cs.sendDone <- err - return + if err != nil { + log.Debugf("p9.send: %v", err) + } + return true } // Try to start the tag. if !cs.StartTag(tag) { // Nothing we can do at this point; client is bogus. log.Debugf("no valid tag [%05d]", tag) - cs.sendDone <- ErrNoValidMessage - return + return true } // Handle the message. @@ -555,23 +581,29 @@ func (cs *connState) handleRequest() { cs.sendMu.Lock() err = send(cs.conn, tag, r) cs.sendMu.Unlock() - cs.sendDone <- err + if err != nil { + log.Debugf("p9.send: %v", err) + } // Return the message to the cache. msgRegistry.put(m) + + return true } func (cs *connState) handleRequests() { - for range cs.recvOkay { - cs.handleRequest() + for { + if !cs.handleRequest() { + return + } } } func (cs *connState) stop() { - // Close all channels. - close(cs.recvOkay) - close(cs.recvDone) - close(cs.sendDone) + // Wait for completion of all inflight requests. This is mostly so that if + // a request is stuck, the sandbox supervisor has the opportunity to kill + // us with SIGABRT to get a stack dump of the offending handler. + cs.pendingWg.Wait() // Free the channels. cs.channelMu.Lock() @@ -590,6 +622,9 @@ func (cs *connState) stop() { cs.channelAlloc.Destroy() } + // Ensure the connection is closed. + cs.conn.Close() + // Close all remaining fids. for fid, fidRef := range cs.fids { delete(cs.fids, fid) @@ -599,74 +634,23 @@ func (cs *connState) stop() { // handlers running via the wait for Pending => 0 below. fidRef.DecRef() } - - // Ensure the connection is closed. - cs.conn.Close() -} - -// service services requests concurrently. -func (cs *connState) service() error { - // Pending is the number of handlers that have finished receiving but - // not finished processing requests. These must be waiting on properly - // below. See the next comment for an explanation of the loop. - pending := 0 - - // Start the first request handler. - go cs.handleRequests() // S/R-SAFE: Irrelevant. - cs.recvOkay <- true - - // We loop and make sure there's always one goroutine waiting for a new - // request. We process all the data for a single request in one - // goroutine however, to ensure the best turnaround time possible. - for { - select { - case err := <-cs.recvDone: - if err != nil { - // Wait for pending handlers. - for i := 0; i < pending; i++ { - <-cs.sendDone - } - return nil - } - - // This handler is now pending. - pending++ - - // Kick the next receiver, or start a new handler - // if no receiver is currently waiting. - select { - case cs.recvOkay <- true: - default: - go cs.handleRequests() // S/R-SAFE: Irrelevant. - cs.recvOkay <- true - } - - case <-cs.sendDone: - // This handler is finished. - pending-- - - // Error sending a response? Nothing can be done. - // - // We don't terminate on a send error though, since - // we still have a pending receive. The error would - // have been logged above, we just ignore it here. - } - } } // Handle handles a single connection. func (s *Server) Handle(conn *unet.Socket) error { cs := &connState{ - server: s, - conn: conn, - fids: make(map[FID]*fidRef), - tags: make(map[Tag]chan struct{}), - recvOkay: make(chan bool), - recvDone: make(chan error, 10), - sendDone: make(chan error, 10), + server: s, + fids: make(map[FID]*fidRef), + tags: make(map[Tag]chan struct{}), + conn: conn, } defer cs.stop() - return cs.service() + + // Serve requests from conn in the current goroutine; handleRequests() will + // create more goroutines as needed. + cs.handleRequests() + + return nil } // Serve handles requests from the bound socket. diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 7cec0e86d..02e665345 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -66,14 +66,17 @@ const ( var dataPool = sync.Pool{ New: func() interface{} { // These buffers are used for decoding without a payload. - return make([]byte, initialBufferLength) + // We need to return a pointer to avoid unnecessary allocations + // (see https://staticcheck.io/docs/checks#SA6002). + b := make([]byte, initialBufferLength) + return &b }, } // send sends the given message over the socket. func send(s *unet.Socket, tag Tag, m message) error { - data := dataPool.Get().([]byte) - dataBuf := buffer{data: data[:0]} + data := dataPool.Get().(*[]byte) + dataBuf := buffer{data: (*data)[:0]} if log.IsLogging(log.Debug) { log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) @@ -141,7 +144,7 @@ func send(s *unet.Socket, tag Tag, m message) error { } // All set. - dataPool.Put(dataBuf.data) + dataPool.Put(&dataBuf.data) return nil } @@ -227,12 +230,29 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, // Not yet initialized. var dataBuf buffer + var vecs [][]byte + + appendBuffer := func(size int) *[]byte { + // Pull a data buffer from the pool. + datap := dataPool.Get().(*[]byte) + data := *datap + if size > len(data) { + // Create a larger data buffer. + data = make([]byte, size) + datap = &data + } else { + // Limit the data buffer. + data = data[:size] + } + dataBuf = buffer{data: data} + vecs = append(vecs, data) + return datap + } // Read the rest of the payload. // // This requires some special care to ensure that the vectors all line // up the way they should. We do this to minimize copying data around. - var vecs [][]byte if payloader, ok := m.(payloader); ok { fixedSize := payloader.FixedSize() @@ -246,22 +266,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, } if fixedSize != 0 { - // Pull a data buffer from the pool. - data := dataPool.Get().([]byte) - if int(fixedSize) > len(data) { - // Create a larger data buffer, ensuring - // sufficient capicity for the message. - data = make([]byte, fixedSize) - defer dataPool.Put(data) - dataBuf = buffer{data: data} - vecs = append(vecs, data) - } else { - // Limit the data buffer, and make sure it - // gets filled before the payload buffer. - defer dataPool.Put(data) - dataBuf = buffer{data: data[:fixedSize]} - vecs = append(vecs, data[:fixedSize]) - } + datap := appendBuffer(int(fixedSize)) + defer dataPool.Put(datap) } // Include the payload. @@ -274,20 +280,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, vecs = append(vecs, p) } } else if remaining != 0 { - // Pull a data buffer from the pool. - data := dataPool.Get().([]byte) - if int(remaining) > len(data) { - // Create a larger data buffer. - data = make([]byte, remaining) - defer dataPool.Put(data) - dataBuf = buffer{data: data} - vecs = append(vecs, data) - } else { - // Limit the data buffer. - defer dataPool.Put(data) - dataBuf = buffer{data: data[:remaining]} - vecs = append(vecs, data[:remaining]) - } + datap := appendBuffer(int(remaining)) + defer dataPool.Put(datap) } if len(vecs) > 0 { diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index 3668fcad7..e7406b374 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -182,6 +182,8 @@ func TestSendClosed(t *testing.T) { } func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + server, client, err := unet.SocketPair(false) if err != nil { b.Fatalf("socketpair got err %v expected nil", err) diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 09cde9f5a..8d7168ef5 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 11 + highestSupportedVersion uint32 = 12 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -173,3 +173,9 @@ func versionSupportsGetSetXattr(v uint32) bool { func versionSupportsListRemoveXattr(v uint32) bool { return v >= 11 } + +// versionSupportsTsetattrclunk returns true if version v supports +// the Tsetattrclunk message. +func versionSupportsTsetattrclunk(v uint32) bool { + return v >= 12 +} diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 7c622e5d7..a45920040 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.16 +// +build !go1.17 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index 48ebb5fd1..9d3b0666d 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.16 +// +build !go1.17 #include "textflag.h" diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 74affc887..9888cce9c 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -24,6 +24,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/context", "//pkg/log", "//pkg/sync", ], @@ -34,5 +35,8 @@ go_test( size = "small", srcs = ["refcounter_test.go"], library = ":refs", - deps = ["//pkg/sync"], + deps = [ + "//pkg/context", + "//pkg/sync", + ], ) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index c45ba8200..699ea8ac3 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -23,6 +23,7 @@ import ( "runtime" "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" ) @@ -38,7 +39,7 @@ type RefCounter interface { // Note that AtomicRefCounter.DecRef() does not support destructors. // If a type has a destructor, it must implement its own DecRef() // method and call AtomicRefCounter.DecRefWithDestructor(destructor). - DecRef() + DecRef(ctx context.Context) // TryIncRef attempts to increase the reference counter on the object, // but may fail if all references have already been dropped. This @@ -57,7 +58,7 @@ type RefCounter interface { // A WeakRefUser is notified when the last non-weak reference is dropped. type WeakRefUser interface { // WeakRefGone is called when the last non-weak reference is dropped. - WeakRefGone() + WeakRefGone(ctx context.Context) } // WeakRef is a weak reference. @@ -123,7 +124,7 @@ func (w *WeakRef) Get() RefCounter { // Drop drops this weak reference. You should always call drop when you are // finished with the weak reference. You may not use this object after calling // drop. -func (w *WeakRef) Drop() { +func (w *WeakRef) Drop(ctx context.Context) { rc, ok := w.get() if !ok { // We've been zapped already. When the refcounter has called @@ -145,7 +146,7 @@ func (w *WeakRef) Drop() { // And now aren't on the object's list of weak references. So it won't // zap us if this causes the reference count to drop to zero. - rc.DecRef() + rc.DecRef(ctx) // Return to the pool. weakRefPool.Put(w) @@ -214,6 +215,8 @@ type AtomicRefCount struct { // LeakMode configures the leak checker. type LeakMode uint32 +// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref +// counting is gone. const ( // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. UninitializedLeakChecking LeakMode = iota @@ -231,6 +234,41 @@ const ( LeaksLogTraces ) +// Set implements flag.Value. +func (l *LeakMode) Set(v string) error { + switch v { + case "disabled": + *l = NoLeakChecking + case "log-names": + *l = LeaksLogWarning + case "log-traces": + *l = LeaksLogTraces + default: + return fmt.Errorf("invalid ref leak mode %q", v) + } + return nil +} + +// Get implements flag.Value. +func (l *LeakMode) Get() interface{} { + return *l +} + +// String implements flag.Value. +func (l *LeakMode) String() string { + switch *l { + case UninitializedLeakChecking: + return "uninitialized" + case NoLeakChecking: + return "disabled" + case LeaksLogWarning: + return "log-names" + case LeaksLogTraces: + return "log-traces" + } + panic(fmt.Sprintf("invalid ref leak mode %d", *l)) +} + // leakMode stores the current mode for the reference leak checker. // // Values must be one of the LeakMode values. @@ -243,6 +281,11 @@ func SetLeakMode(mode LeakMode) { atomic.StoreUint32(&leakMode, uint32(mode)) } +// GetLeakMode returns the current leak mode. +func GetLeakMode() LeakMode { + return LeakMode(atomic.LoadUint32(&leakMode)) +} + const maxStackFrames = 40 type fileLine struct { @@ -427,7 +470,7 @@ func (r *AtomicRefCount) dropWeakRef(w *WeakRef) { // A: TryIncRef [transform speculative to real] // //go:nosplit -func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { +func (r *AtomicRefCount) DecRefWithDestructor(ctx context.Context, destroy func(context.Context)) { switch v := atomic.AddInt64(&r.refCount, -1); { case v < -1: panic("Decrementing non-positive ref count") @@ -448,7 +491,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { if user != nil { r.mu.Unlock() - user.WeakRefGone() + user.WeakRefGone(ctx) r.mu.Lock() } } @@ -456,7 +499,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { // Call the destructor. if destroy != nil { - destroy() + destroy(ctx) } } } @@ -464,6 +507,16 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { // DecRef decrements this object's reference count. // //go:nosplit -func (r *AtomicRefCount) DecRef() { - r.DecRefWithDestructor(nil) +func (r *AtomicRefCount) DecRef(ctx context.Context) { + r.DecRefWithDestructor(ctx, nil) +} + +// OnExit is called on sandbox exit. It runs GC to enqueue refcount finalizers, +// which check for reference leaks. There is no way to guarantee that every +// finalizer will run before exiting, but this at least ensures that they will +// be discovered/enqueued by GC. +func OnExit() { + if LeakMode(atomic.LoadUint32(&leakMode)) != NoLeakChecking { + runtime.GC() + } } diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go index 1ab4a4440..6d0dd1018 100644 --- a/pkg/refs/refcounter_test.go +++ b/pkg/refs/refcounter_test.go @@ -18,6 +18,7 @@ import ( "reflect" "testing" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -31,11 +32,11 @@ type testCounter struct { destroyed bool } -func (t *testCounter) DecRef() { - t.AtomicRefCount.DecRefWithDestructor(t.destroy) +func (t *testCounter) DecRef(ctx context.Context) { + t.AtomicRefCount.DecRefWithDestructor(ctx, t.destroy) } -func (t *testCounter) destroy() { +func (t *testCounter) destroy(context.Context) { t.mu.Lock() defer t.mu.Unlock() t.destroyed = true @@ -53,7 +54,7 @@ func newTestCounter() *testCounter { func TestOneRef(t *testing.T) { tc := newTestCounter() - tc.DecRef() + tc.DecRef(context.Background()) if !tc.IsDestroyed() { t.Errorf("object should have been destroyed") @@ -63,8 +64,9 @@ func TestOneRef(t *testing.T) { func TestTwoRefs(t *testing.T) { tc := newTestCounter() tc.IncRef() - tc.DecRef() - tc.DecRef() + ctx := context.Background() + tc.DecRef(ctx) + tc.DecRef(ctx) if !tc.IsDestroyed() { t.Errorf("object should have been destroyed") @@ -74,12 +76,13 @@ func TestTwoRefs(t *testing.T) { func TestMultiRefs(t *testing.T) { tc := newTestCounter() tc.IncRef() - tc.DecRef() + ctx := context.Background() + tc.DecRef(ctx) tc.IncRef() - tc.DecRef() + tc.DecRef(ctx) - tc.DecRef() + tc.DecRef(ctx) if !tc.IsDestroyed() { t.Errorf("object should have been destroyed") @@ -89,19 +92,20 @@ func TestMultiRefs(t *testing.T) { func TestWeakRef(t *testing.T) { tc := newTestCounter() w := NewWeakRef(tc, nil) + ctx := context.Background() // Try resolving. if x := w.Get(); x == nil { t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) } else { - x.DecRef() + x.DecRef(ctx) } // Try resolving again. if x := w.Get(); x == nil { t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) } else { - x.DecRef() + x.DecRef(ctx) } // Shouldn't be destroyed yet. (Can't continue if this fails.) @@ -110,7 +114,7 @@ func TestWeakRef(t *testing.T) { } // Drop the original reference. - tc.DecRef() + tc.DecRef(ctx) // Assert destroyed. if !tc.IsDestroyed() { @@ -126,7 +130,8 @@ func TestWeakRef(t *testing.T) { func TestWeakRefDrop(t *testing.T) { tc := newTestCounter() w := NewWeakRef(tc, nil) - w.Drop() + ctx := context.Background() + w.Drop(ctx) // Just assert the list is empty. if !tc.weakRefs.Empty() { @@ -134,14 +139,14 @@ func TestWeakRefDrop(t *testing.T) { } // Drop the original reference. - tc.DecRef() + tc.DecRef(ctx) } type testWeakRefUser struct { weakRefGone func() } -func (u *testWeakRefUser) WeakRefGone() { +func (u *testWeakRefUser) WeakRefGone(ctx context.Context) { u.weakRefGone() } @@ -165,7 +170,8 @@ func TestCallback(t *testing.T) { }}) // Drop the original reference, this must trigger the callback. - tc.DecRef() + ctx := context.Background() + tc.DecRef(ctx) if !called { t.Fatalf("Callback not called") diff --git a/pkg/refs_vfs2/BUILD b/pkg/refs_vfs2/BUILD new file mode 100644 index 000000000..577b827a5 --- /dev/null +++ b/pkg/refs_vfs2/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template") + +package(licenses = ["notice"]) + +go_template( + name = "refs_template", + srcs = [ + "refs_template.go", + ], + types = [ + "T", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//pkg/refs", + ], +) + +go_library( + name = "refs_vfs2", + srcs = ["refs.go"], + visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/context"], +) diff --git a/pkg/refs_vfs2/refs.go b/pkg/refs_vfs2/refs.go new file mode 100644 index 000000000..99a074e96 --- /dev/null +++ b/pkg/refs_vfs2/refs.go @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package refs_vfs2 defines an interface for a reference-counted object. +package refs_vfs2 + +import ( + "gvisor.dev/gvisor/pkg/context" +) + +// RefCounter is the interface to be implemented by objects that are reference +// counted. +type RefCounter interface { + // IncRef increments the reference counter on the object. + IncRef() + + // DecRef decrements the object's reference count. Users of refs_template.Refs + // may specify a destructor to be called once the reference count reaches zero. + DecRef(ctx context.Context) + + // TryIncRef attempts to increment the reference count, but may fail if all + // references have already been dropped, in which case it returns false. If + // true is returned, then a valid reference is now held on the object. + TryIncRef() bool +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refs_vfs2/refs_template.go new file mode 100644 index 000000000..d9b552896 --- /dev/null +++ b/pkg/refs_vfs2/refs_template.go @@ -0,0 +1,142 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package refs_template defines a template that can be used by reference +// counted objects. The "owner" template parameter is used in log messages to +// indicate the type of reference-counted object that exhibited a reference +// leak. As a result, structs that are embedded in other structs should not use +// this template, since it will make tracking down leaks more difficult. +package refs_template + +import ( + "fmt" + "runtime" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" +) + +// T is the type of the reference counted object. It is only used to customize +// debug output when leak checking. +type T interface{} + +// ownerType is used to customize logging. Note that we use a pointer to T so +// that we do not copy the entire object when passed as a format parameter. +var ownerType *T + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// Note that the number of references is actually refCount + 1 so that a default +// zero-value Refs object contains one reference. +// +// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in +// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. +// This will allow us to add stack trace information to the leak messages +// without growing the size of Refs. +// +// +stateify savable +type Refs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +func (r *Refs) finalize() { + var note string + switch refs_vfs1.GetLeakMode() { + case refs_vfs1.NoLeakChecking: + return + case refs_vfs1.UninitializedLeakChecking: + note = "(Leak checker uninitialized): " + } + if n := r.ReadRefs(); n != 0 { + log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) + } +} + +// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. +func (r *Refs) EnableLeakCheck() { + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + runtime.SetFinalizer(r, (*Refs).finalize) + } +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *Refs) ReadRefs() int64 { + // Account for the internal -1 offset on refcounts. + return atomic.LoadInt64(&r.refCount) + 1 +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *Refs) IncRef() { + if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { + panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *Refs) TryIncRef() bool { + const speculativeRef = 1 << 32 + v := atomic.AddInt64(&r.refCount, speculativeRef) + if int32(v) < 0 { + // This object has already been freed. + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + // Turn into a real reference. + atomic.AddInt64(&r.refCount, -speculativeRef+1) + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *Refs) DecRef(destroy func()) { + switch v := atomic.AddInt64(&r.refCount, -1); { + case v < -1: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) + + case v == -1: + // Call the destructor. + if destroy != nil { + destroy() + } + } +} diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD index ce30382ab..68ed074f8 100644 --- a/pkg/safemem/BUILD +++ b/pkg/safemem/BUILD @@ -11,9 +11,7 @@ go_library( "seq_unsafe.go", ], visibility = ["//:sandbox"], - deps = [ - "//pkg/safecopy", - ], + deps = ["//pkg/safecopy"], ) go_test( diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go index f5f0574f8..fc4049eeb 100644 --- a/pkg/safemem/seq_unsafe.go +++ b/pkg/safemem/seq_unsafe.go @@ -91,9 +91,10 @@ func BlockSeqFromSlice(slice []Block) BlockSeq { return blockSeqFromSliceLimited(slice, limit) } -// Preconditions: The combined length of all Blocks in slice <= limit. If -// len(slice) != 0, the first Block in slice has non-zero length, and limit > -// 0. +// Preconditions: +// * The combined length of all Blocks in slice <= limit. +// * If len(slice) != 0, the first Block in slice has non-zero length and +// limit > 0. func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq { switch len(slice) { case 0: diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index c5fca2ba3..e828894b0 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -5,7 +5,12 @@ package(licenses = ["notice"]) go_binary( name = "victim", testonly = 1, - srcs = ["seccomp_test_victim.go"], + srcs = [ + "seccomp_test_victim.go", + "seccomp_test_victim_amd64.go", + "seccomp_test_victim_arm64.go", + ], + nogo = False, deps = [":seccomp"], ) @@ -44,7 +49,7 @@ go_test( library = ":seccomp", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/bpf", + "//pkg/usermem", ], ) diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 55fd6967e..752e2dc32 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package seccomp provides basic seccomp filters for x86_64 (little endian). +// Package seccomp provides generation of basic seccomp filters. Currently, +// only little endian systems are supported. package seccomp import ( @@ -64,9 +65,9 @@ func Install(rules SyscallRules) error { Rules: rules, Action: linux.SECCOMP_RET_ALLOW, }, - }, defaultAction) + }, defaultAction, defaultAction) if log.IsLogging(log.Debug) { - programStr, errDecode := bpf.DecodeProgram(instrs) + programStr, errDecode := bpf.DecodeInstructions(instrs) if errDecode != nil { programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr) } @@ -117,7 +118,7 @@ var SyscallName = func(sysno uintptr) string { // BuildProgram builds a BPF program from the given map of actions to matching // SyscallRules. The single generated program covers all provided RuleSets. -func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) { +func BuildProgram(rules []RuleSet, defaultAction, badArchAction linux.BPFAction) ([]linux.BPFInstruction, error) { program := bpf.NewProgramBuilder() // Be paranoid and check that syscall is done in the expected architecture. @@ -128,7 +129,7 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn // defaultLabel is at the bottom of the program. The size of program // may exceeds 255 lines, which is the limit of a condition jump. program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0) - program.AddDirectJumpLabel(defaultLabel) + program.AddStmt(bpf.Ret|bpf.K, uint32(badArchAction)) if err := buildIndex(rules, program); err != nil { return nil, err } @@ -144,6 +145,11 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn // buildIndex builds a BST to quickly search through all syscalls. func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error { + // Do nothing if rules is empty. + if len(rules) == 0 { + return nil + } + // Build a list of all application system calls, across all given rule // sets. We have a simple BST, but may dispatch individual matchers // with different actions. The matchers are evaluated linearly. @@ -216,42 +222,163 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc labelled := false for i, arg := range rule { if arg != nil { + // Break out early if using MatchAny since no further + // instructions are required. + if _, ok := arg.(MatchAny); ok { + continue + } + + // Determine the data offset for low and high bits of input. + dataOffsetLow := seccompDataOffsetArgLow(i) + dataOffsetHigh := seccompDataOffsetArgHigh(i) + if i == RuleIP { + dataOffsetLow = seccompDataOffsetIPLow + dataOffsetHigh = seccompDataOffsetIPHigh + } + + // Add the conditional operation. Input values to the BPF + // program are 64bit values. However, comparisons in BPF can + // only be done on 32bit values. This means that we need to do + // multiple BPF comparisons in order to do one logical 64bit + // comparison. switch a := arg.(type) { - case AllowAny: - case AllowValue: - dataOffsetLow := seccompDataOffsetArgLow(i) - dataOffsetHigh := seccompDataOffsetArgHigh(i) - if i == RuleIP { - dataOffsetLow = seccompDataOffsetIPLow - dataOffsetHigh = seccompDataOffsetIPHigh - } + case EqualTo: + // EqualTo checks that both the higher and lower 32bits are equal. high, low := uint32(a>>32), uint32(a) - // assert arg_low == low + + // Assert that the lower 32bits are equal. + // arg_low == low ? continue : violation p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) - // assert arg_high == high + + // Assert that the lower 32bits are also equal. + // arg_high == high ? continue/success : violation p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) labelled = true + case NotEqual: + // NotEqual checks that either the higher or lower 32bits + // are *not* equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("ne%v", i) + + // Check if the higher 32bits are (not) equal. + // arg_low == low ? continue : success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are not equal (assuming + // higher bits are equal). + // arg_high == high ? violation : continue/success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true case GreaterThan: - dataOffsetLow := seccompDataOffsetArgLow(i) - dataOffsetHigh := seccompDataOffsetArgHigh(i) - if i == RuleIP { - dataOffsetLow = seccompDataOffsetIPLow - dataOffsetHigh = seccompDataOffsetIPHigh - } - labelGood := fmt.Sprintf("gt%v", i) + // GreaterThan checks that the higher 32bits is greater + // *or* that the higher 32bits are equal and the lower + // 32bits are greater. high, low := uint32(a>>32), uint32(a) - // assert arg_high < high + labelGood := fmt.Sprintf("gt%v", i) + + // Assert the higher 32bits are greater than or equal. + // arg_high >= high ? continue : violation (arg_high < high) p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) - // arg_high > high + + // Assert that the lower 32bits are greater. + // arg_high == high ? continue : success (arg_high > high) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) - // arg_low < low + // arg_low > low ? continue/success : violation (arg_high == high and arg_low <= low) p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) labelled = true + case GreaterThanOrEqual: + // GreaterThanOrEqual checks that the higher 32bits is + // greater *or* that the higher 32bits are equal and the + // lower 32bits are greater than or equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("ge%v", i) + + // Assert the higher 32bits are greater than or equal. + // arg_high >= high ? continue : violation (arg_high < high) + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + // arg_high == high ? continue : success (arg_high > high) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are greater (assuming the + // higher bits are equal). + // arg_low >= low ? continue/success : violation (arg_high == high and arg_low < low) + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case LessThan: + // LessThan checks that the higher 32bits is less *or* that + // the higher 32bits are equal and the lower 32bits are + // less. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("lt%v", i) + + // Assert the higher 32bits are less than or equal. + // arg_high > high ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + // arg_high == high ? continue : success (arg_high < high) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are less (assuming the + // higher bits are equal). + // arg_low >= low ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jge|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case LessThanOrEqual: + // LessThan checks that the higher 32bits is less *or* that + // the higher 32bits are equal and the lower 32bits are + // less than or equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("le%v", i) + + // Assert the higher 32bits are less than or equal. + // assert arg_high > high ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + // arg_high == high ? continue : success + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert the lower bits are less than or equal (assuming + // the higher bits are equal). + // arg_low > low ? violation : success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case maskedEqual: + // MaskedEqual checks that the bitwise AND of the value and + // mask are equal for both the higher and lower 32bits. + high, low := uint32(a.value>>32), uint32(a.value) + maskHigh, maskLow := uint32(a.mask>>32), uint32(a.mask) + + // Assert that the lower 32bits are equal when masked. + // A <- arg_low. + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + // A <- arg_low & maskLow + p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskLow) + // Assert that arg_low & maskLow == low. + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + + // Assert that the higher 32bits are equal when masked. + // A <- arg_high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + // A <- arg_high & maskHigh + p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskHigh) + // Assert that arg_high & maskHigh == high. + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + labelled = true default: return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) } diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index a52dc1b4e..daf165bbf 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -39,28 +39,79 @@ func seccompDataOffsetArgHigh(i int) uint32 { return seccompDataOffsetArgLow(i) + 4 } -// AllowAny is marker to indicate any value will be accepted. -type AllowAny struct{} +// MatchAny is marker to indicate any value will be accepted. +type MatchAny struct{} -func (a AllowAny) String() (s string) { +func (a MatchAny) String() (s string) { return "*" } -// AllowValue specifies a value that needs to be strictly matched. -type AllowValue uintptr +// EqualTo specifies a value that needs to be strictly matched. +type EqualTo uintptr + +func (a EqualTo) String() (s string) { + return fmt.Sprintf("== %#x", uintptr(a)) +} + +// NotEqual specifies a value that is strictly not equal. +type NotEqual uintptr + +func (a NotEqual) String() (s string) { + return fmt.Sprintf("!= %#x", uintptr(a)) +} // GreaterThan specifies a value that needs to be strictly smaller. type GreaterThan uintptr -func (a AllowValue) String() (s string) { - return fmt.Sprintf("%#x ", uintptr(a)) +func (a GreaterThan) String() (s string) { + return fmt.Sprintf("> %#x", uintptr(a)) +} + +// GreaterThanOrEqual specifies a value that needs to be smaller or equal. +type GreaterThanOrEqual uintptr + +func (a GreaterThanOrEqual) String() (s string) { + return fmt.Sprintf(">= %#x", uintptr(a)) +} + +// LessThan specifies a value that needs to be strictly greater. +type LessThan uintptr + +func (a LessThan) String() (s string) { + return fmt.Sprintf("< %#x", uintptr(a)) +} + +// LessThanOrEqual specifies a value that needs to be greater or equal. +type LessThanOrEqual uintptr + +func (a LessThanOrEqual) String() (s string) { + return fmt.Sprintf("<= %#x", uintptr(a)) +} + +type maskedEqual struct { + mask uintptr + value uintptr +} + +func (a maskedEqual) String() (s string) { + return fmt.Sprintf("& %#x == %#x", a.mask, a.value) +} + +// MaskedEqual specifies a value that matches the input after the input is +// masked (bitwise &) against the given mask. Can be used to verify that input +// only includes certain approved flags. +func MaskedEqual(mask, value uintptr) interface{} { + return maskedEqual{ + mask: mask, + value: value, + } } // Rule stores the allowed syscall arguments. // // For example: // rule := Rule { -// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0 +// EqualTo(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0 // } type Rule [7]interface{} // 6 arguments + RIP @@ -89,12 +140,12 @@ func (r Rule) String() (s string) { // rules := SyscallRules{ // syscall.SYS_FUTEX: []Rule{ // { -// AllowAny{}, -// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), +// MatchAny{}, +// EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), // }, // OR // { -// AllowAny{}, -// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), +// MatchAny{}, +// EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), // }, // }, // syscall.SYS_GETPID: []Rule{}, diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 88766f33b..e1444d18b 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -28,17 +28,10 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/usermem" ) -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - // newVictim makes a victim binary. func newVictim() (string, error) { f, err := ioutil.TempFile("", "victim") @@ -58,9 +51,14 @@ func newVictim() (string, error) { return path, nil } -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +// dataAsInput converts a linux.SeccompData to a bpf.Input. +func dataAsInput(d *linux.SeccompData) bpf.Input { + buf := make([]byte, d.SizeBytes()) + d.MarshalUnsafe(buf) + return bpf.InputBytes{ + Data: buf, + Order: usermem.ByteOrder, + } } func TestBasic(t *testing.T) { @@ -69,18 +67,21 @@ func TestBasic(t *testing.T) { desc string // data is the input data. - data seccompData + data linux.SeccompData // want is the expected return value of the BPF program. want linux.BPFAction } for _, test := range []struct { + name string ruleSets []RuleSet defaultAction linux.BPFAction + badArchAction linux.BPFAction specs []spec }{ { + name: "Single syscall", ruleSets: []RuleSet{ { Rules: SyscallRules{1: {}}, @@ -88,26 +89,28 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Single syscall allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + desc: "syscall allowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Single syscall disallowed", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, + desc: "syscall disallowed", + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Multiple rulesets", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0x1), + EqualTo(0x1), }, }, }, @@ -122,30 +125,32 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_KILL_THREAD, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Multiple rulesets allowed (1a)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x1}}, + desc: "allowed (1a)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple rulesets allowed (1b)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + desc: "allowed (1b)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + desc: "syscall 1 matched 2nd rule", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, + desc: "no match", + data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, }, { + name: "Multiple syscalls", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -157,50 +162,52 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Multiple syscalls allowed (1)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, + desc: "allowed (1)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls allowed (3)", - data: seccompData{nr: 3, arch: linux.AUDIT_ARCH_X86_64}, + desc: "allowed (3)", + data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls allowed (5)", - data: seccompData{nr: 5, arch: linux.AUDIT_ARCH_X86_64}, + desc: "allowed (5)", + data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls disallowed (0)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, + desc: "disallowed (0)", + data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (2)", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, + desc: "disallowed (2)", + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (4)", - data: seccompData{nr: 4, arch: linux.AUDIT_ARCH_X86_64}, + desc: "disallowed (4)", + data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (6)", - data: seccompData{nr: 6, arch: linux.AUDIT_ARCH_X86_64}, + desc: "disallowed (6)", + data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (100)", - data: seccompData{nr: 100, arch: linux.AUDIT_ARCH_X86_64}, + desc: "disallowed (100)", + data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Wrong architecture", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -210,15 +217,17 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Wrong architecture", - data: seccompData{nr: 1, arch: 123}, - want: linux.SECCOMP_RET_TRAP, + desc: "arch (123)", + data: linux.SeccompData{Nr: 1, Arch: 123}, + want: linux.SECCOMP_RET_KILL_THREAD, }, }, }, { + name: "Syscall disallowed", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -228,22 +237,24 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall disallowed, action trap", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, + desc: "action trap", + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Syscall arguments", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowAny{}, - AllowValue(0xf), + MatchAny{}, + EqualTo(0xf), }, }, }, @@ -251,29 +262,31 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xf}}, + desc: "allowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Syscall argument disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xe}}, + desc: "disallowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Multiple arguments", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0xf), + EqualTo(0xf), }, { - AllowValue(0xe), + EqualTo(0xe), }, }, }, @@ -281,28 +294,30 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf}}, + desc: "match first rule", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xe}}, + desc: "match 2nd rule", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}}, want: linux.SECCOMP_RET_ALLOW, }, }, }, { + name: "EqualTo", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0), - AllowValue(math.MaxUint64 - 1), - AllowValue(math.MaxUint32), + EqualTo(0), + EqualTo(math.MaxUint64 - 1), + EqualTo(math.MaxUint32), }, }, }, @@ -310,37 +325,135 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "64bit syscall argument allowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, + desc: "argument allowed (all match)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, }, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "64bit syscall argument disallowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, + desc: "argument disallowed (one mismatch)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, }, want: linux.SECCOMP_RET_TRAP, }, { - desc: "64bit syscall argument disallowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + desc: "argument disallowed (multiple mismatch)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "NotEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + NotEqual(0x7aabbccdd), + NotEqual(math.MaxUint64 - 1), + NotEqual(math.MaxUint32), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (one equal)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (all equal)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThan", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + GreaterThan(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThan (multi)", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -355,46 +468,410 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (first arg smaller)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg smaller)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThanOrEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + GreaterThanOrEqual(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThanOrEqual (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + GreaterThanOrEqual(0xf), + GreaterThanOrEqual(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed (both greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (first arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg smaller)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg allowed (second arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (second arg smaller)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg smaller)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "LessThan", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + LessThan(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "GreaterThan: Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}}, + desc: "high 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "GreaterThan: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}}, + desc: "high 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + }, + }, + { + name: "LessThan (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + LessThan(0x1), + LessThan(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Syscall argument disallowed (smaller)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}}, + desc: "arg disallowed (first arg greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "GreaterThan2: Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}}, + desc: "arg disallowed (second arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "LessThanOrEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + LessThanOrEqual(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "GreaterThan2: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}}, + desc: "high 32bits equal, low 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits less", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + }, + }, + + { + name: "LessThanOrEqual (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + LessThanOrEqual(0x1), + LessThanOrEqual(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (first arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg allowed (second arg equal)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (second arg greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg greater)", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "MaskedEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // x & 00000001 00000011 (0x103) == 00000000 00000001 (0x1) + // Input x must have lowest order bit set and + // must *not* have 8th or second lowest order bit set. + MaskedEqual(0x103, 0x1), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed (low order mandatory bit)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000001 + Args: [6]uint64{0x1}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (low order optional bit)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000101 + Args: [6]uint64{0x5}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (lowest order bit not set)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000010 + Args: [6]uint64{0x2}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second lowest order bit set)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000011 + Args: [6]uint64{0x3}, + }, want: linux.SECCOMP_RET_TRAP, }, { - desc: "GreaterThan2: Syscall argument disallowed (smaller)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}}, + desc: "arg disallowed (8th bit set)", + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000001 00000000 + Args: [6]uint64{0x100}, + }, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Instruction Pointer", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - RuleIP: AllowValue(0x7aabbccdd), + RuleIP: EqualTo(0x7aabbccdd), }, }, }, @@ -402,40 +879,42 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "IP: Syscall instruction pointer allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, + desc: "allowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "IP: Syscall instruction pointer disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x711223344}, + desc: "disallowed", + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344}, want: linux.SECCOMP_RET_TRAP, }, }, }, } { - instrs, err := BuildProgram(test.ruleSets, test.defaultAction) - if err != nil { - t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err) - continue - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err) - continue - } - for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) + t.Run(test.name, func(t *testing.T) { + instrs, err := BuildProgram(test.ruleSets, test.defaultAction, test.badArchAction) if err != nil { - t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err) - continue + t.Fatalf("BuildProgram() got error: %v", err) } - if got != uint32(spec.want) { - t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want) + p, err := bpf.Compile(instrs) + if err != nil { + t.Fatalf("bpf.Compile() got error: %v", err) } - } + for _, spec := range test.specs { + got, err := bpf.Exec(p, dataAsInput(&spec.data)) + if err != nil { + t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err) + } + if got != uint32(spec.want) { + // Include a decoded version of the program in output for debugging purposes. + decoded, _ := bpf.DecodeInstructions(instrs) + t.Fatalf("%s: got: %d, want: %d\nBPF Program\n%s", spec.desc, got, spec.want, decoded) + } + } + }) } } @@ -457,7 +936,7 @@ func TestRandom(t *testing.T) { Rules: syscallRules, Action: linux.SECCOMP_RET_ALLOW, }, - }, linux.SECCOMP_RET_TRAP) + }, linux.SECCOMP_RET_TRAP, linux.SECCOMP_RET_KILL_THREAD) if err != nil { t.Fatalf("buildProgram() got error: %v", err) } @@ -466,8 +945,8 @@ func TestRandom(t *testing.T) { t.Fatalf("bpf.Compile() got error: %v", err) } for i := uint32(0); i < 200; i++ { - data := seccompData{nr: i, arch: linux.AUDIT_ARCH_X86_64} - got, err := bpf.Exec(p, data.asInput()) + data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH} + got, err := bpf.Exec(p, dataAsInput(&data)) if err != nil { t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) continue diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index da6b9eaaf..7f33e0d9e 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -31,7 +31,6 @@ func main() { syscalls := seccomp.SyscallRules{ syscall.SYS_ACCEPT: {}, - syscall.SYS_ARCH_PRCTL: {}, syscall.SYS_BIND: {}, syscall.SYS_BRK: {}, syscall.SYS_CLOCK_GETTIME: {}, @@ -41,7 +40,6 @@ func main() { syscall.SYS_DUP3: {}, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, - syscall.SYS_EPOLL_WAIT: {}, syscall.SYS_EPOLL_PWAIT: {}, syscall.SYS_EXIT: {}, syscall.SYS_EXIT_GROUP: {}, @@ -68,8 +66,6 @@ func main() { syscall.SYS_MUNLOCK: {}, syscall.SYS_MUNMAP: {}, syscall.SYS_NANOSLEEP: {}, - syscall.SYS_NEWFSTATAT: {}, - syscall.SYS_OPEN: {}, syscall.SYS_PPOLL: {}, syscall.SYS_PREAD64: {}, syscall.SYS_PSELECT6: {}, @@ -97,11 +93,14 @@ func main() { syscall.SYS_WRITE: {}, syscall.SYS_WRITEV: {}, } + + arch_syscalls(syscalls) + die := *dieFlag if !die { syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{ { - seccomp.AllowValue(10), + seccomp.EqualTo(10), }, } } diff --git a/pkg/seccomp/seccomp_test_victim_amd64.go b/pkg/seccomp/seccomp_test_victim_amd64.go new file mode 100644 index 000000000..5dfc68e25 --- /dev/null +++ b/pkg/seccomp/seccomp_test_victim_amd64.go @@ -0,0 +1,32 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test binary used to test that seccomp filters are properly constructed and +// indeed kill the process on violation. + +// +build amd64 + +package main + +import ( + "gvisor.dev/gvisor/pkg/seccomp" + "syscall" +) + +func arch_syscalls(syscalls seccomp.SyscallRules) { + syscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{} + syscalls[syscall.SYS_EPOLL_WAIT] = []seccomp.Rule{} + syscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} + syscalls[syscall.SYS_OPEN] = []seccomp.Rule{} +} diff --git a/pkg/seccomp/seccomp_test_victim_arm64.go b/pkg/seccomp/seccomp_test_victim_arm64.go new file mode 100644 index 000000000..5184d8ac4 --- /dev/null +++ b/pkg/seccomp/seccomp_test_victim_arm64.go @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test binary used to test that seccomp filters are properly constructed and +// indeed kill the process on violation. + +// +build arm64 + +package main + +import ( + "gvisor.dev/gvisor/pkg/seccomp" + "syscall" +) + +func arch_syscalls(syscalls seccomp.SyscallRules) { + syscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{} +} diff --git a/pkg/segment/set.go b/pkg/segment/set.go index 1a17ad9cb..fbb31dbea 100644 --- a/pkg/segment/set.go +++ b/pkg/segment/set.go @@ -407,7 +407,9 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // and returns an iterator to the inserted segment. All existing iterators // (including gap, but not including the returned iterator) are invalidated. // -// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { gap = gap.node.rebalanceBeforeInsert(gap) splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) @@ -1211,12 +1213,10 @@ func (seg Iterator) End() Key { // does not invalidate any iterators. // // Preconditions: -// -// - r.Length() > 0. -// -// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), -// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then -// r.start >= seg.PrevSegment().End(). +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). func (seg Iterator) SetRangeUnchecked(r Range) { seg.node.keys[seg.index] = r } @@ -1241,8 +1241,9 @@ func (seg Iterator) SetRange(r Range) { // SetStartUnchecked mutates the iterated segment's start. This operation does // not invalidate any iterators. // -// Preconditions: The new start must be valid: start < seg.End(); if -// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). func (seg Iterator) SetStartUnchecked(start Key) { seg.node.keys[seg.index].Start = start } @@ -1264,8 +1265,9 @@ func (seg Iterator) SetStart(start Key) { // SetEndUnchecked mutates the iterated segment's end. This operation does not // invalidate any iterators. // -// Preconditions: The new end must be valid: end > seg.Start(); if -// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). func (seg Iterator) SetEndUnchecked(end Key) { seg.node.keys[seg.index].End = end } @@ -1695,9 +1697,11 @@ func (s *Set) ExportSortedSlices() *SegmentDataSlices { // ImportSortedSlice initializes the given set from the given slice. // -// Preconditions: s must be empty. sds must represent a valid set (the segments -// in sds must have valid lengths that do not overlap). The segments in sds -// must be sorted in ascending key order. +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { if !s.IsEmpty() { return fmt.Errorf("cannot import into non-empty set %v", s) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 901e0f320..4af4d6e84 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -22,6 +22,7 @@ go_library( "signal_info.go", "signal_stack.go", "stack.go", + "stack_unsafe.go", "syscalls_amd64.go", "syscalls_arm64.go", ], @@ -33,11 +34,12 @@ go_library( "//pkg/context", "//pkg/cpuid", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/limits", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index a903d031c..d75d665ae 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -72,12 +73,12 @@ type Context interface { // with return values of varying sizes (for example ARCH_GETFS). This // is a simple utility function to convert to the native size in these // cases, and then we can CopyOut. - Native(val uintptr) interface{} + Native(val uintptr) marshal.Marshallable // Value converts a native type back to a generic value. // Once a value has been converted to native via the above call -- it // can be converted back here. - Value(val interface{}) uintptr + Value(val marshal.Marshallable) uintptr // Width returns the number of bytes for a native value. Width() uint @@ -205,7 +206,7 @@ type Context interface { // equivalent of arch_ptrace(): // PtracePeekUser implements ptrace(PTRACE_PEEKUSR). - PtracePeekUser(addr uintptr) (interface{}, error) + PtracePeekUser(addr uintptr) (marshal.Marshallable, error) // PtracePokeUser implements ptrace(PTRACE_POKEUSR). PtracePokeUser(addr, data uintptr) error diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index daba8b172..fd73751e7 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -28,7 +28,14 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs + + // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register. + TPIDR_EL0 uint64 +} const ( // SyscallWidth is the width of insturctions. @@ -94,6 +101,8 @@ func NewFloatingPointData() *FloatingPointData { // State contains the common architecture bits for aarch64 (the build tag of this // file ensures it's only built on aarch64). +// +// +stateify savable type State struct { // The system registers. Regs Registers @@ -101,9 +110,6 @@ type State struct { // Our floating point state. aarch64FPState `state:"wait"` - // TLS pointer - TPValue uint64 - // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -148,6 +154,7 @@ func (s State) Proto() *rpb.Registers { Sp: s.Regs.Sp, Pc: s.Regs.Pc, Pstate: s.Regs.Pstate, + Tls: s.Regs.TPIDR_EL0, } return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}} } @@ -157,7 +164,6 @@ func (s *State) Fork() State { return State{ Regs: s.Regs, aarch64FPState: s.aarch64FPState.fork(), - TPValue: s.TPValue, FeatureSet: s.FeatureSet, OrigR0: s.OrigR0, } @@ -227,6 +233,7 @@ func (s *State) RegisterMap() (map[string]uintptr, error) { "Sp": uintptr(s.Regs.Sp), "Pc": uintptr(s.Regs.Pc), "Pstate": uintptr(s.Regs.Pstate), + "Tls": uintptr(s.Regs.TPIDR_EL0), }, nil } @@ -241,18 +248,18 @@ func (s *State) ptraceGetRegs() Registers { return s.Regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } regs.UnmarshalUnsafe(buf) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // PtraceGetFPRegs implements Context.PtraceGetFPRegs. @@ -278,7 +285,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -291,7 +298,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 3b3a0a272..c7d3a206d 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -23,6 +23,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -179,14 +181,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { } // Native returns the native type for the given val. -func (c *context64) Native(val uintptr) interface{} { - v := uint64(val) +func (c *context64) Native(val uintptr) marshal.Marshallable { + v := primitive.Uint64(val) return &v } // Value returns the generic val for the given native type. -func (c *context64) Value(val interface{}) uintptr { - return uintptr(*val.(*uint64)) +func (c *context64) Value(val marshal.Marshallable) uintptr { + return uintptr(*val.(*primitive.Uint64)) } // Width returns the byte width of this architecture. @@ -293,14 +295,14 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { const userStructSize = 928 // PtracePeekUser implements Context.PtracePeekUser. -func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { +func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { if addr&7 != 0 || addr >= userStructSize { return nil, syscall.EIO } // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and // u_debugreg, returning 0 or silently no-oping for other fields // respectively. - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) @@ -315,7 +317,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { if addr&7 != 0 || addr >= userStructSize { return syscall.EIO } - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index ada7ac7b8..680d23a9f 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -22,6 +22,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -73,6 +75,8 @@ const ( ) // context64 represents an ARM64 context. +// +// +stateify savable type context64 struct { State sigFPState []aarch64FPState // fpstate to be restored on sigreturn. @@ -142,7 +146,7 @@ func (c *context64) SetStack(value uintptr) { // TLS returns the current TLS pointer. func (c *context64) TLS() uintptr { - return uintptr(c.TPValue) + return uintptr(c.Regs.TPIDR_EL0) } // SetTLS sets the current TLS pointer. Returns false if value is invalid. @@ -151,7 +155,7 @@ func (c *context64) SetTLS(value uintptr) bool { return false } - c.TPValue = uint64(value) + c.Regs.TPIDR_EL0 = uint64(value) return true } @@ -161,14 +165,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { } // Native returns the native type for the given val. -func (c *context64) Native(val uintptr) interface{} { - v := uint64(val) +func (c *context64) Native(val uintptr) marshal.Marshallable { + v := primitive.Uint64(val) return &v } // Value returns the generic val for the given native type. -func (c *context64) Value(val interface{}) uintptr { - return uintptr(*val.(*uint64)) +func (c *context64) Value(val marshal.Marshallable) uintptr { + return uintptr(*val.(*primitive.Uint64)) } // Width returns the byte width of this architecture. @@ -272,7 +276,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { } // PtracePeekUser implements Context.PtracePeekUser. -func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { +func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64. return c.Native(0), nil } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index dc458b37f..b9405b320 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -31,7 +31,11 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs +} // System-related constants for x86. const ( @@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers { return regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } @@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) { } regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // isUserSegmentSelector returns true if the given segment selector specifies a @@ -543,7 +547,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto index 60c027aab..2727ba08a 100644 --- a/pkg/sentry/arch/registers.proto +++ b/pkg/sentry/arch/registers.proto @@ -83,6 +83,7 @@ message ARM64Registers { uint64 sp = 32; uint64 pc = 33; uint64 pstate = 34; + uint64 tls = 35; } message Registers { oneof arch { diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go index 32173aa20..d3e2324a8 100644 --- a/pkg/sentry/arch/signal_act.go +++ b/pkg/sentry/arch/signal_act.go @@ -14,7 +14,7 @@ package arch -import "gvisor.dev/gvisor/tools/go_marshal/marshal" +import "gvisor.dev/gvisor/pkg/marshal" // Special values for SignalAct.Handler. const ( diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 6fb756f0e..72e07a988 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -17,17 +17,19 @@ package arch import ( - "encoding/binary" "math" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the // second argument to signal handlers set by signal(2). +// +// +marshal type SignalContext64 struct { R8 uint64 R9 uint64 @@ -68,6 +70,8 @@ const ( ) // UContext64 is equivalent to ucontext_t on 64-bit x86. +// +// +marshal type UContext64 struct { Flags uint64 Link uint64 @@ -172,12 +176,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // "... the value (%rsp+8) is always a multiple of 16 (...) when // control is transferred to the function entry point." - AMD64 ABI - ucSize := binary.Size(uc) - if ucSize < 0 { - // This can only happen if we've screwed up the definition of - // UContext64. - panic("can't get size of UContext64") - } + ucSize := uc.SizeBytes() // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128. frameSize := int(st.Arch.Width()) + ucSize + 128 frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8 @@ -195,18 +194,18 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt info.FixSignalCodeForUser() // Set up the stack frame. - infoAddr, err := st.Push(info) - if err != nil { + if _, err := info.CopyOut(st, StackBottomMagic); err != nil { return err } - ucAddr, err := st.Push(uc) - if err != nil { + infoAddr := st.Bottom + if _, err := uc.CopyOut(st, StackBottomMagic); err != nil { return err } + ucAddr := st.Bottom if act.HasRestorer() { // Push the restorer return address. // Note that this doesn't need to be popped. - if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil { + if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil { return err } } else { @@ -240,11 +239,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { // Copy out the stack frame. var uc UContext64 - if _, err := st.Pop(&uc); err != nil { + if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } var info SignalInfo - if _, err := st.Pop(&info); err != nil { + if _, err := info.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 642c79dda..7fde5d34e 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package arch import ( - "encoding/binary" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +26,8 @@ import ( // SignalContext64 is equivalent to struct sigcontext, the type passed as the // second argument to signal handlers set by signal(2). +// +// +marshal type SignalContext64 struct { FaultAddr uint64 Regs [31]uint64 @@ -36,6 +39,7 @@ type SignalContext64 struct { Reserved [3568]uint8 } +// +marshal type aarch64Ctx struct { Magic uint32 Size uint32 @@ -43,6 +47,8 @@ type aarch64Ctx struct { // FpsimdContext is equivalent to struct fpsimd_context on arm64 // (arch/arm64/include/uapi/asm/sigcontext.h). +// +// +marshal type FpsimdContext struct { Head aarch64Ctx Fpsr uint32 @@ -51,13 +57,15 @@ type FpsimdContext struct { } // UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h). +// +// +marshal type UContext64 struct { Flags uint64 Link uint64 Stack SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t - _pad [(1024 - 64) / 8]byte + _pad [120]byte // (1024 - 64) / 8 = 120 // sigcontext must be aligned to 16-byte _pad2 [8]byte // last for future expansion @@ -94,11 +102,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt }, Sigset: sigset, } - - ucSize := binary.Size(uc) - if ucSize < 0 { - panic("can't get size of UContext64") - } + ucSize := uc.SizeBytes() // frameSize = ucSize + sizeof(siginfo). // sizeof(siginfo) == 128. @@ -119,14 +123,14 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt info.FixSignalCodeForUser() // Set up the stack frame. - infoAddr, err := st.Push(info) - if err != nil { + if _, err := info.CopyOut(st, StackBottomMagic); err != nil { return err } - ucAddr, err := st.Push(uc) - if err != nil { + infoAddr := st.Bottom + if _, err := uc.CopyOut(st, StackBottomMagic); err != nil { return err } + ucAddr := st.Bottom // Set up registers. c.Regs.Sp = uint64(st.Bottom) @@ -147,11 +151,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { // Copy out the stack frame. var uc UContext64 - if _, err := st.Pop(&uc); err != nil { + if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } var info SignalInfo - if _, err := st.Pop(&info); err != nil { + if _, err := info.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go index 0fa738a1d..a1eae98f9 100644 --- a/pkg/sentry/arch/signal_stack.go +++ b/pkg/sentry/arch/signal_stack.go @@ -17,8 +17,8 @@ package arch import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) const ( diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 1108fa0bd..5f06c751d 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -15,14 +15,16 @@ package arch import ( - "encoding/binary" - "fmt" - "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" ) -// Stack is a simple wrapper around a usermem.IO and an address. +// Stack is a simple wrapper around a usermem.IO and an address. Stack +// implements marshal.CopyContext, and marshallable values can be pushed or +// popped from the stack through the marshal.Marshallable interface. +// +// Stack is not thread-safe. type Stack struct { // Our arch info. // We use this for automatic Native conversion of usermem.Addrs during @@ -34,105 +36,60 @@ type Stack struct { // Our current stack bottom. Bottom usermem.Addr -} -// Push pushes the given values on to the stack. -// -// (This method supports Addrs and treats them as native types.) -func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) { - for _, v := range vals { - - // We convert some types to well-known serializable quanities. - var norm interface{} - - // For array types, we will automatically add an appropriate - // terminal value. This is done simply to make the interface - // easier to use. - var term interface{} - - switch v.(type) { - case string: - norm = []byte(v.(string)) - term = byte(0) - case []int8, []uint8: - norm = v - term = byte(0) - case []int16, []uint16: - norm = v - term = uint16(0) - case []int32, []uint32: - norm = v - term = uint32(0) - case []int64, []uint64: - norm = v - term = uint64(0) - case []usermem.Addr: - // Special case: simply push recursively. - _, err := s.Push(s.Arch.Native(uintptr(0))) - if err != nil { - return 0, err - } - varr := v.([]usermem.Addr) - for i := len(varr) - 1; i >= 0; i-- { - _, err := s.Push(varr[i]) - if err != nil { - return 0, err - } - } - continue - case usermem.Addr: - norm = s.Arch.Native(uintptr(v.(usermem.Addr))) - default: - norm = v - } + // Scratch buffer used for marshalling to avoid having to repeatedly + // allocate scratch memory. + scratchBuf []byte +} - if term != nil { - _, err := s.Push(term) - if err != nil { - return 0, err - } - } +// scratchBufLen is the default length of Stack.scratchBuf. The +// largest structs the stack regularly serializes are arch.SignalInfo +// and arch.UContext64. We'll set the default size as the larger of +// the two, arch.UContext64. +var scratchBufLen = (*UContext64)(nil).SizeBytes() - c := binary.Size(norm) - if c < 0 { - return 0, fmt.Errorf("bad binary.Size for %T", v) - } - n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{}) - if err != nil || c != n { - return 0, err - } +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (s *Stack) CopyScratchBuffer(size int) []byte { + if len(s.scratchBuf) < size { + s.scratchBuf = make([]byte, size) + } + return s.scratchBuf[:size] +} +// StackBottomMagic is the special address callers must past to all stack +// marshalling operations to cause the src/dst address to be computed based on +// the current end of the stack. +const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1) + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes +// computes an appropriate address based on the current end of the +// stack. Callers use the sentinel address StackBottomMagic to marshal methods +// to indicate this. +func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) { + if sentinel != StackBottomMagic { + panic("Attempted to copy out to stack with absolute address") + } + c := len(b) + n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{}) + if err == nil && n == c { s.Bottom -= usermem.Addr(n) } - - return s.Bottom, nil + return n, err } -// Pop pops the given values off the stack. -// -// (This method supports Addrs and treats them as native types.) -func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) { - for _, v := range vals { - - vaddr, isVaddr := v.(*usermem.Addr) - - var n int - var err error - if isVaddr { - value := s.Arch.Native(uintptr(0)) - n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{}) - *vaddr = usermem.Addr(s.Arch.Value(value)) - } else { - n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{}) - } - if err != nil { - return 0, err - } - +// CopyInBytes implements marshal.CopyContext.CopyInBytes. CopyInBytes computes +// an appropriate address based on the current end of the stack. Callers must +// use the sentinel address StackBottomMagic to marshal methods to indicate +// this. +func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) { + if sentinel != StackBottomMagic { + panic("Attempted to copy in from stack with absolute address") + } + n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{}) + if err == nil { s.Bottom += usermem.Addr(n) } - - return s.Bottom, nil + return n, err } // Align aligns the stack to the given offset. @@ -142,6 +99,22 @@ func (s *Stack) Align(offset int) { } } +// PushNullTerminatedByteSlice writes bs to the stack, followed by an extra null +// byte at the end. On error, the contents of the stack and the bottom cursor +// are undefined. +func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) { + // Note: Stack grows up, so write the terminal null byte first. + nNull, err := primitive.CopyUint8Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + n, err := primitive.CopyByteSliceOut(s, StackBottomMagic, bs) + if err != nil { + return 0, err + } + return n + nNull, nil +} + // StackLayout describes the location of the arguments and environment on the // stack. type StackLayout struct { @@ -177,11 +150,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) l.EnvvEnd = s.Bottom envAddrs := make([]usermem.Addr, len(env)) for i := len(env) - 1; i >= 0; i-- { - addr, err := s.Push(env[i]) - if err != nil { + if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil { return StackLayout{}, err } - envAddrs[i] = addr + envAddrs[i] = s.Bottom } l.EnvvStart = s.Bottom @@ -189,11 +161,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) l.ArgvEnd = s.Bottom argAddrs := make([]usermem.Addr, len(args)) for i := len(args) - 1; i >= 0; i-- { - addr, err := s.Push(args[i]) - if err != nil { + if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil { return StackLayout{}, err } - argAddrs[i] = addr + argAddrs[i] = s.Bottom } l.ArgvStart = s.Bottom @@ -222,26 +193,26 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) auxv = append(auxv, usermem.Addr(a.Key), a.Value) } auxv = append(auxv, usermem.Addr(0)) - _, err := s.Push(auxv) + _, err := s.pushAddrSliceAndTerminator(auxv) if err != nil { return StackLayout{}, err } // Push environment. - _, err = s.Push(envAddrs) + _, err = s.pushAddrSliceAndTerminator(envAddrs) if err != nil { return StackLayout{}, err } // Push args. - _, err = s.Push(argAddrs) + _, err = s.pushAddrSliceAndTerminator(argAddrs) if err != nil { return StackLayout{}, err } // Push arg count. - _, err = s.Push(usermem.Addr(len(args))) - if err != nil { + lenP := s.Arch.Native(uintptr(len(args))) + if _, err = lenP.CopyOut(s, StackBottomMagic); err != nil { return StackLayout{}, err } diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go new file mode 100644 index 000000000..a90d297ee --- /dev/null +++ b/pkg/sentry/arch/stack_unsafe.go @@ -0,0 +1,69 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arch + +import ( + "reflect" + "runtime" + "unsafe" + + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/usermem" +) + +// pushAddrSliceAndTerminator copies a slices of addresses to the stack, and +// also pushes an extra null address element at the end of the slice. +// +// Internally, we unsafely transmute the slice type from the arch-dependent +// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to +// go-marshal. +// +// On error, the contents of the stack and the bottom cursor are undefined. +func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) { + // Note: Stack grows upwards, so push the terminator first. + srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + switch s.Arch.Width() { + case 8: + nNull, err := primitive.CopyUint64Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + var dst []uint64 + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstHdr.Data = srcHdr.Data + dstHdr.Len = srcHdr.Len + dstHdr.Cap = srcHdr.Cap + n, err := primitive.CopyUint64SliceOut(s, StackBottomMagic, dst) + // Ensures src doesn't get GCed until we're done using it through dst. + runtime.KeepAlive(src) + return n + nNull, err + case 4: + nNull, err := primitive.CopyUint32Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + var dst []uint32 + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstHdr.Data = srcHdr.Data + dstHdr.Len = srcHdr.Len + dstHdr.Cap = srcHdr.Cap + n, err := primitive.CopyUint32SliceOut(s, StackBottomMagic, dst) + // Ensure src doesn't get GCed until we're done using it through dst. + runtime.KeepAlive(src) + return n + nNull, err + default: + panic("Unsupported arch width") + } +} diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go index 8e5658c7a..dfd195a23 100644 --- a/pkg/sentry/contexttest/contexttest.go +++ b/pkg/sentry/contexttest/contexttest.go @@ -144,27 +144,7 @@ func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { // RootContext returns a Context that may be used in tests that need root // credentials. Uses ptrace as the platform.Platform. func RootContext(tb testing.TB) context.Context { - return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) -} - -// WithCreds returns a copy of ctx carrying creds. -func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context { - return &authContext{ctx, creds} -} - -type authContext struct { - context.Context - creds *auth.Credentials -} - -// Value implements context.Context. -func (ac *authContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return ac.creds - default: - return ac.Context.Value(key) - } + return auth.ContextWithCredentials(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) } // WithLimitSet returns a copy of ctx carrying l. diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 2c5d14be5..deaf5fa23 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -35,7 +35,6 @@ go_library( "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 663e51989..2bf3c45e1 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -49,6 +49,9 @@ type ProfileOpts struct { // - dump out the stack trace of current go routines. // sentryctl -pid <pid> pprof-goroutine type Profile struct { + // Kernel is the kernel under profile. It's immutable. + Kernel *kernel.Kernel + // mu protects the fields below. mu sync.Mutex @@ -57,9 +60,6 @@ type Profile struct { // traceFile is the current execution trace output file. traceFile *fd.FD - - // Kernel is the kernel under profile. - Kernel *kernel.Kernel } // StartCPUProfile is an RPC stub which starts recording the CPU profile in a diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 1bae7cfaf..668f47802 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -23,8 +23,8 @@ import ( "text/tabwriter" "time" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fdimport" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" @@ -139,7 +139,6 @@ func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { // Import file descriptors. fdTable := proc.Kernel.NewFDTable() - defer fdTable.DecRef() creds := auth.NewUserCredentials( args.KUID, @@ -177,6 +176,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI initArgs.MountNamespaceVFS2.IncRef() } ctx := initArgs.NewContext(proc.Kernel) + defer fdTable.DecRef(ctx) if kernel.VFS2Enabled { // Get the full path to the filename from the PATH env variable. @@ -203,27 +203,17 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI } initArgs.Filename = resolved - fds := make([]int, len(args.FilePayload.Files)) - for i, file := range args.FilePayload.Files { - if kernel.VFS2Enabled { - // Need to dup to remove ownership from os.File. - dup, err := unix.Dup(int(file.Fd())) - if err != nil { - return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err) - } - fds[i] = dup - } else { - // VFS1 dups the file on import. - fds[i] = int(file.Fd()) - } + fds, err := fd.NewFromFiles(args.Files) + if err != nil { + return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err) } + defer func() { + for _, fd := range fds { + _ = fd.Close() + } + }() ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, args.StdioIsPty, fds) if err != nil { - if kernel.VFS2Enabled { - for _, fd := range fds { - unix.Close(fd) - } - } return nil, 0, nil, nil, err } diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index f45b2bd2b..6ca9dc79f 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -256,7 +256,7 @@ func (m *MultiDevice) Load(key MultiDeviceKey, value uint64) bool { } if k, exists := m.rcache[value]; exists && k != key { // Should never happen. - panic("MultiDevice's caches are inconsistent") + panic(fmt.Sprintf("MultiDevice's caches are inconsistent, current: %+v, previous: %+v", key, k)) } // Cache value at key. diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD index abe58f818..4c8604d58 100644 --- a/pkg/sentry/devices/memdev/BUILD +++ b/pkg/sentry/devices/memdev/BUILD @@ -18,9 +18,10 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go index af66fe4dc..fece3e762 100644 --- a/pkg/sentry/devices/memdev/full.go +++ b/pkg/sentry/devices/memdev/full.go @@ -24,6 +24,8 @@ import ( const fullDevMinor = 7 // fullDevice implements vfs.Device for /dev/full. +// +// +stateify savable type fullDevice struct{} // Open implements vfs.Device.Open. @@ -38,6 +40,8 @@ func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // fullFD implements vfs.FileDescriptionImpl for /dev/full. +// +// +stateify savable type fullFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -46,7 +50,7 @@ type fullFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *fullFD) Release() { +func (fd *fullFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go index 92d3d71be..ff5837747 100644 --- a/pkg/sentry/devices/memdev/null.go +++ b/pkg/sentry/devices/memdev/null.go @@ -25,6 +25,8 @@ import ( const nullDevMinor = 3 // nullDevice implements vfs.Device for /dev/null. +// +// +stateify savable type nullDevice struct{} // Open implements vfs.Device.Open. @@ -39,6 +41,8 @@ func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // nullFD implements vfs.FileDescriptionImpl for /dev/null. +// +// +stateify savable type nullFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -47,7 +51,7 @@ type nullFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *nullFD) Release() { +func (fd *nullFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go index 6b81da5ef..ac943e3ba 100644 --- a/pkg/sentry/devices/memdev/random.go +++ b/pkg/sentry/devices/memdev/random.go @@ -30,6 +30,8 @@ const ( ) // randomDevice implements vfs.Device for /dev/random and /dev/urandom. +// +// +stateify savable type randomDevice struct{} // Open implements vfs.Device.Open. @@ -44,6 +46,8 @@ func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, } // randomFD implements vfs.FileDescriptionImpl for /dev/random. +// +// +stateify savable type randomFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -56,7 +60,7 @@ type randomFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *randomFD) Release() { +func (fd *randomFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index c6f15054d..1929e41cd 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -16,9 +16,10 @@ package memdev import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/usermem" ) @@ -26,6 +27,8 @@ import ( const zeroDevMinor = 5 // zeroDevice implements vfs.Device for /dev/zero. +// +// +stateify savable type zeroDevice struct{} // Open implements vfs.Device.Open. @@ -40,6 +43,8 @@ func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // zeroFD implements vfs.FileDescriptionImpl for /dev/zero. +// +// +stateify savable type zeroFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -48,7 +53,7 @@ type zeroFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *zeroFD) Release() { +func (fd *zeroFD) Release(context.Context) { // noop } @@ -79,11 +84,22 @@ func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) + if opts.Private || !opts.MaxPerms.Write { + // This mapping will never permit writing to the "underlying file" (in + // Linux terms, it isn't VM_SHARED), so implement it as an anonymous + // mapping, but back it with fd; this is what Linux does, and is + // actually application-visible because the resulting VMA will show up + // in /proc/[pid]/maps with fd.vfsfd.VirtualDentry()'s path rather than + // "/dev/zero (deleted)". + opts.Offset = 0 + opts.MappingIdentity = &fd.vfsfd + opts.MappingIdentity.IncRef() + return nil + } + tmpfsFD, err := tmpfs.NewZeroFile(ctx, auth.CredentialsFromContext(ctx), kernel.KernelFromContext(ctx).ShmMount(), opts.Length) if err != nil { return err } - opts.MappingIdentity = m - opts.Mappable = m - return nil + defer tmpfsFD.DecRef(ctx) + return tmpfsFD.ConfigureMMap(ctx, opts) } diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD index 12e49b58a..b4b6ca38a 100644 --- a/pkg/sentry/devices/ttydev/BUILD +++ b/pkg/sentry/devices/ttydev/BUILD @@ -11,6 +11,6 @@ go_library( "//pkg/context", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/vfs", - "//pkg/usermem", + "//pkg/syserror", ], ) diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go index fbb7fd92c..a287c65ca 100644 --- a/pkg/sentry/devices/ttydev/ttydev.go +++ b/pkg/sentry/devices/ttydev/ttydev.go @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ttydev implements devices for /dev/tty and (eventually) -// /dev/console. -// -// TODO(b/159623826): Support /dev/console. +// Package ttydev implements an unopenable vfs.Device for /dev/tty. package ttydev import ( @@ -23,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -33,48 +30,13 @@ const ( ) // ttyDevice implements vfs.Device for /dev/tty. +// +// +stateify savable type ttyDevice struct{} // Open implements vfs.Device.Open. func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd := &ttyFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ - UseDentryMetadata: true, - }); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - -// ttyFD implements vfs.FileDescriptionImpl for /dev/tty. -type ttyFD struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.DentryMetadataFileDescriptionImpl - vfs.NoLockFD -} - -// Release implements vfs.FileDescriptionImpl.Release. -func (fd *ttyFD) Release() {} - -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return 0, nil -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (fd *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return 0, nil -} - -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return src.NumBytes(), nil -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (fd *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return src.NumBytes(), nil + return nil, syserror.EIO } // Register registers all devices implemented by this package in vfsObj. diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index dfbd069af..0b701a289 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -64,12 +64,13 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg request := args[1].Uint() data := args[2].Pointer() + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + switch request { case linux.TUNSETIFF: - t := kernel.TaskFromContext(ctx) - if t == nil { - panic("Ioctl should be called from a task context") - } if !t.HasCapability(linux.CAP_NET_ADMIN) { return 0, syserror.EPERM } @@ -79,9 +80,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg } var req linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := req.CopyIn(t, data); err != nil { return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) @@ -97,9 +96,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg flags := fd.device.Flags() | linux.IFF_NOFILTER usermem.ByteOrder.PutUint16(req.Data[:], flags) - _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := req.CopyOut(t, data) return 0, err default: @@ -108,8 +105,8 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *tunFD) Release() { - fd.device.Release() +func (fd *tunFD) Release(ctx context.Context) { + fd.device.Release(ctx) } // PRead implements vfs.FileDescriptionImpl.PRead. @@ -160,8 +157,8 @@ func (fd *tunFD) EventUnregister(e *waiter.Entry) { fd.device.EventUnregister(e) } -// isNetTunSupported returns whether /dev/net/tun device is supported for s. -func isNetTunSupported(s inet.Stack) bool { +// IsNetTunSupported returns whether /dev/net/tun device is supported for s. +func IsNetTunSupported(s inet.Stack) bool { _, ok := s.(*netstack.Stack) return ok } diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD index 5e41ceb4e..6b4f8b0ed 100644 --- a/pkg/sentry/fdimport/BUILD +++ b/pkg/sentry/fdimport/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/fd", "//pkg/sentry/fs", "//pkg/sentry/fs/host", "//pkg/sentry/fsimpl/host", diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index a4199f9e9..314661475 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -15,7 +15,10 @@ package fdimport import ( + "fmt" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" @@ -25,8 +28,9 @@ import ( // Import imports a slice of FDs into the given FDTable. If console is true, // sets up TTY for the first 3 FDs in the slice representing stdin, stdout, -// stderr. Upon success, Import takes ownership of all FDs. -func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { +// stderr. Used FDs are either closed or released. It's safe for the caller to +// close any remaining files upon return. +func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { if kernel.VFS2Enabled { ttyFile, err := importVFS2(ctx, fdTable, console, fds) return nil, ttyFile, err @@ -35,7 +39,7 @@ func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []in return ttyFile, nil, err } -func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, error) { +func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, error) { var ttyFile *fs.File for appFD, hostFD := range fds { var appFile *fs.File @@ -44,11 +48,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] // Import the file as a host TTY file. if ttyFile == nil { var err error - appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */) + appFile, err = host.ImportFile(ctx, hostFD.FD(), true /* isTTY */) if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) + _ = hostFD.Close() // FD is dup'd i ImportFile. // Remember this in the TTY file, as we will // use it for the other stdio FDs. @@ -63,11 +68,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] } else { // Import the file as a regular host file. var err error - appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */) + appFile, err = host.ImportFile(ctx, hostFD.FD(), false /* isTTY */) if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) + _ = hostFD.Close() // FD is dup'd i ImportFile. } // Add the file to the FD map. @@ -82,8 +88,11 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] return ttyFile.FileOperations.(*host.TTYFileOperations), nil } -func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) { +func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []*fd.FD) (*hostvfs2.TTYFileDescription, error) { k := kernel.KernelFromContext(ctx) + if k == nil { + return nil, fmt.Errorf("cannot find kernel from context") + } var ttyFile *vfs.FileDescription for appFD, hostFD := range stdioFDs { @@ -93,11 +102,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi // Import the file as a host TTY file. if ttyFile == nil { var err error - appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, true /* isTTY */) + appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), true /* isTTY */) if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) + hostFD.Release() // FD is transfered to host FD. // Remember this in the TTY file, as we will use it for the other stdio // FDs. @@ -110,11 +120,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi } } else { var err error - appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */) + appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), false /* isTTY */) if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) + hostFD.Release() // FD is transfered to host FD. } if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil { diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ab1424c95..ff2fe6712 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -107,8 +107,7 @@ func copyUp(ctx context.Context, d *Dirent) error { // leave the upper filesystem filled with any number of parent directories // but the upper filesystem will never be in an inconsistent state. // -// Preconditions: -// - d.Inode.overlay is non-nil. +// Preconditions: d.Inode.overlay is non-nil. func copyUpLockedForRename(ctx context.Context, d *Dirent) error { for { // Did we race with another copy up or does there @@ -183,12 +182,12 @@ func doCopyUp(ctx context.Context, d *Dirent) error { // Returns a generic error on failure. // // Preconditions: -// - parent.Inode.overlay.upper must be non-nil. -// - next.Inode.overlay.copyMu must be locked writable. -// - next.Inode.overlay.lower must be non-nil. -// - next.Inode.overlay.lower.StableAttr.Type must be RegularFile, Directory, +// * parent.Inode.overlay.upper must be non-nil. +// * next.Inode.overlay.copyMu must be locked writable. +// * next.Inode.overlay.lower must be non-nil. +// * next.Inode.overlay.lower.StableAttr.Type must be RegularFile, Directory, // or Symlink. -// - upper filesystem must support setting file ownership and timestamps. +// * upper filesystem must support setting file ownership and timestamps. func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { // Extract the attributes of the file we wish to copy. attrs, err := next.Inode.overlay.lower.UnstableAttr(ctx) @@ -201,7 +200,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { parentUpper := parent.Inode.overlay.upper root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } // Create the file in the upper filesystem and get an Inode for it. @@ -212,7 +211,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { log.Warningf("copy up failed to create file: %v", err) return syserror.EIO } - defer childFile.DecRef() + defer childFile.DecRef(ctx) childUpperInode = childFile.Dirent.Inode case Directory: @@ -226,7 +225,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { cleanupUpper(ctx, parentUpper, next.name, werr) return syserror.EIO } - defer childUpper.DecRef() + defer childUpper.DecRef(ctx) childUpperInode = childUpper.Inode case Symlink: @@ -246,7 +245,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { cleanupUpper(ctx, parentUpper, next.name, werr) return syserror.EIO } - defer childUpper.DecRef() + defer childUpper.DecRef(ctx) childUpperInode = childUpper.Inode default: @@ -352,14 +351,14 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in if err != nil { return err } - defer upperFile.DecRef() + defer upperFile.DecRef(ctx) // Get a handle to the lower filesystem, which we will read from. lowerFile, err := overlayFile(ctx, lower, FileFlags{Read: true}) if err != nil { return err } - defer lowerFile.DecRef() + defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. buf := copyUpBuffers.Get().([]byte) diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 91792d9fe..c7a11eec1 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -126,7 +126,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { if err != nil { t.Fatalf("failed to create file %q: %v", name, err) } - defer f.DecRef() + defer f.DecRef(ctx) relname, _ := f.Dirent.FullName(lowerRoot) @@ -171,7 +171,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { if err != nil { t.Fatalf("failed to find %q: %v", f.name, err) } - defer d.DecRef() + defer d.DecRef(ctx) f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) if err != nil { diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index dc7ad075a..5f8c9b5a2 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -80,8 +80,8 @@ type netTunFileOperations struct { var _ fs.FileOperations = (*netTunFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (fops *netTunFileOperations) Release() { - fops.device.Release() +func (fops *netTunFileOperations) Release(ctx context.Context) { + fops.device.Release(ctx) } // Ioctl implements fs.FileOperations.Ioctl. @@ -89,12 +89,13 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u request := args[1].Uint() data := args[2].Pointer() + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + switch request { case linux.TUNSETIFF: - t := kernel.TaskFromContext(ctx) - if t == nil { - panic("Ioctl should be called from a task context") - } if !t.HasCapability(linux.CAP_NET_ADMIN) { return 0, syserror.EPERM } @@ -104,9 +105,7 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u } var req linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := req.CopyIn(t, data); err != nil { return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) @@ -122,9 +121,7 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u flags := fops.device.Flags() | linux.IFF_NOFILTER usermem.ByteOrder.PutUint16(req.Data[:], flags) - _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := req.CopyOut(t, data) return 0, err default: diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 65be12175..00c526b03 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -325,7 +325,7 @@ func (d *Dirent) SyncAll(ctx context.Context) { for _, w := range d.children { if child := w.Get(); child != nil { child.(*Dirent).SyncAll(ctx) - child.DecRef() + child.DecRef(ctx) } } } @@ -413,9 +413,9 @@ func (d *Dirent) descendantOf(p *Dirent) bool { // Inode.Lookup, otherwise walk will keep d.mu locked. // // Preconditions: -// - renameMu must be held for reading. -// - d.mu must be held. -// - name must must not contain "/"s. +// * renameMu must be held for reading. +// * d.mu must be held. +// * name must must not contain "/"s. func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnlock bool) (*Dirent, error) { if !IsDir(d.Inode.StableAttr) { return nil, syscall.ENOTDIR @@ -451,7 +451,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // which don't hold a hard reference on their parent (their parent holds a // hard reference on them, and they contain virtually no state). But this is // good house-keeping. - child.DecRef() + child.DecRef(ctx) return nil, syscall.ENOENT } @@ -468,20 +468,20 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // their pins on the child. Inotify doesn't properly support filesystems that // revalidate dirents (since watches are lost on revalidation), but if we fail // to unpin the watches child will never be GCed. - cd.Inode.Watches.Unpin(cd) + cd.Inode.Watches.Unpin(ctx, cd) // This child needs to be revalidated, fallthrough to unhash it. Make sure // to not leak a reference from Get(). // // Note that previous lookups may still have a reference to this stale child; // this can't be helped, but we can ensure that *new* lookups are up-to-date. - child.DecRef() + child.DecRef(ctx) } // Either our weak reference expired or we need to revalidate it. Unhash child first, we're // about to replace it. delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be @@ -512,12 +512,12 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // There are active references to the existing child, prefer it to the one we // retrieved from Lookup. Likely the Lookup happened very close to the insertion // of child, so considering one stale over the other is fairly arbitrary. - c.DecRef() + c.DecRef(ctx) // The child that was installed could be negative. if cd.IsNegative() { // If so, don't leak a reference and short circuit. - child.DecRef() + child.DecRef(ctx) return nil, syscall.ENOENT } @@ -531,7 +531,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // we did the Inode.Lookup. Fully drop the weak reference and fallback to using the child // we looked up. delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Give the looked up child a parent. We cannot kick out entries, since we just checked above @@ -577,9 +577,9 @@ func (d *Dirent) Walk(ctx context.Context, root *Dirent, name string) (*Dirent, // exists returns true if name exists in relation to d. // // Preconditions: -// - renameMu must be held for reading. -// - d.mu must be held. -// - name must must not contain "/"s. +// * renameMu must be held for reading. +// * d.mu must be held. +// * name must must not contain "/"s. func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool { child, err := d.walk(ctx, root, name, false /* may unlock */) if err != nil { @@ -587,7 +587,7 @@ func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool { return false } // Child exists. - child.DecRef() + child.DecRef(ctx) return true } @@ -622,7 +622,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi } child := file.Dirent - d.finishCreate(child, name) + d.finishCreate(ctx, child, name) // Return the reference and the new file. When the last reference to // the file is dropped, file.Dirent may no longer be cached. @@ -631,7 +631,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi // finishCreate validates the created file, adds it as a child of this dirent, // and notifies any watchers. -func (d *Dirent) finishCreate(child *Dirent, name string) { +func (d *Dirent) finishCreate(ctx context.Context, child *Dirent, name string) { // Sanity check c, its name must be consistent. if child.name != name { panic(fmt.Sprintf("create from %q to %q returned unexpected name %q", d.name, name, child.name)) @@ -650,14 +650,14 @@ func (d *Dirent) finishCreate(child *Dirent, name string) { panic(fmt.Sprintf("hashed child %q over a positive child", child.name)) } // Don't leak a reference. - old.DecRef() + old.DecRef(ctx) // Drop d's reference. - old.DecRef() + old.DecRef(ctx) } // Finally drop the useless weak reference on the floor. - w.Drop() + w.Drop(ctx) } d.Inode.Watches.Notify(name, linux.IN_CREATE, 0) @@ -686,17 +686,17 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c panic(fmt.Sprintf("hashed over a positive child %q", old.(*Dirent).name)) } // Don't leak a reference. - old.DecRef() + old.DecRef(ctx) // Drop d's reference. - old.DecRef() + old.DecRef(ctx) } // Unhash the negative Dirent, name needs to exist now. delete(d.children, name) // Finally drop the useless weak reference on the floor. - w.Drop() + w.Drop(ctx) } // Execute the create operation. @@ -756,7 +756,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans if e != nil { return e } - d.finishCreate(childDir, name) + d.finishCreate(ctx, childDir, name) return nil }) if err == syscall.EEXIST { @@ -901,7 +901,7 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent, // references to children. // // Preconditions: d.mu must be held. -func (d *Dirent) flush() { +func (d *Dirent) flush(ctx context.Context) { expired := make(map[string]*refs.WeakRef) for n, w := range d.children { // Call flush recursively on each child before removing our @@ -912,7 +912,7 @@ func (d *Dirent) flush() { if !cd.IsNegative() { // Flush the child. cd.mu.Lock() - cd.flush() + cd.flush(ctx) cd.mu.Unlock() // Allow the file system to drop extra references on child. @@ -920,13 +920,13 @@ func (d *Dirent) flush() { } // Don't leak a reference. - child.DecRef() + child.DecRef(ctx) } // Check if the child dirent is closed, and mark it as expired if it is. // We must call w.Get() again here, since the child could have been closed // by the calls to flush() and cache.Remove() in the above if-block. if child := w.Get(); child != nil { - child.DecRef() + child.DecRef(ctx) } else { expired[n] = w } @@ -935,7 +935,7 @@ func (d *Dirent) flush() { // Remove expired entries. for n, w := range expired { delete(d.children, n) - w.Drop() + w.Drop(ctx) } } @@ -977,7 +977,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err if !ok { panic("mount must mount over an existing dirent") } - weakRef.Drop() + weakRef.Drop(ctx) // Note that even though `d` is now hidden, it still holds a reference // to its parent. @@ -1002,13 +1002,13 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { if !ok { panic("mount must mount over an existing dirent") } - weakRef.Drop() + weakRef.Drop(ctx) // d is not reachable anymore, and hence not mounted anymore. d.mounted = false // Drop mount reference. - d.DecRef() + d.DecRef(ctx) return nil } @@ -1029,7 +1029,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath // Child does not exist. return err } - defer child.DecRef() + defer child.DecRef(ctx) // Remove cannot remove directories. if IsDir(child.Inode.StableAttr) { @@ -1055,7 +1055,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath atomic.StoreInt32(&child.deleted, 1) if w, ok := d.children[name]; ok { delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Allow the file system to drop extra references on child. @@ -1067,7 +1067,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath // inode may have other links. If this was the last link, the events for the // watch removal will be queued by the inode destructor. child.Inode.Watches.MarkUnlinked() - child.Inode.Watches.Unpin(child) + child.Inode.Watches.Unpin(ctx, child) d.Inode.Watches.Notify(name, linux.IN_DELETE, 0) return nil @@ -1100,7 +1100,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) // Child does not exist. return err } - defer child.DecRef() + defer child.DecRef(ctx) // RemoveDirectory can only remove directories. if !IsDir(child.Inode.StableAttr) { @@ -1121,7 +1121,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) atomic.StoreInt32(&child.deleted, 1) if w, ok := d.children[name]; ok { delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Allow the file system to drop extra references on child. @@ -1130,14 +1130,14 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) // Finally, let inotify know the child is being unlinked. Drop any extra // refs from inotify to this child dirent. child.Inode.Watches.MarkUnlinked() - child.Inode.Watches.Unpin(child) + child.Inode.Watches.Unpin(ctx, child) d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_DELETE, 0) return nil } // destroy closes this node and all children. -func (d *Dirent) destroy() { +func (d *Dirent) destroy(ctx context.Context) { if d.IsNegative() { // Nothing to tear-down and no parent references to drop, since a negative // Dirent does not take a references on its parent, has no Inode and no children. @@ -1153,19 +1153,19 @@ func (d *Dirent) destroy() { if c.(*Dirent).IsNegative() { // The parent holds both weak and strong refs in the case of // negative dirents. - c.DecRef() + c.DecRef(ctx) } // Drop the reference we just acquired in WeakRef.Get. - c.DecRef() + c.DecRef(ctx) } - w.Drop() + w.Drop(ctx) } d.children = nil allDirents.remove(d) // Drop our reference to the Inode. - d.Inode.DecRef() + d.Inode.DecRef(ctx) // Allow the Dirent to be GC'ed after this point, since the Inode may still // be referenced after the Dirent is destroyed (for instance by filesystem @@ -1175,7 +1175,7 @@ func (d *Dirent) destroy() { // Drop the reference we have on our parent if we took one. renameMu doesn't need to be // held because d can't be reparented without any references to it left. if d.parent != nil { - d.parent.DecRef() + d.parent.DecRef(ctx) } } @@ -1201,14 +1201,14 @@ func (d *Dirent) TryIncRef() bool { // DecRef decreases the Dirent's refcount and drops its reference on its mount. // // DecRef implements RefCounter.DecRef with destructor d.destroy. -func (d *Dirent) DecRef() { +func (d *Dirent) DecRef(ctx context.Context) { if d.Inode != nil { // Keep mount around, since DecRef may destroy d.Inode. msrc := d.Inode.MountSource - d.DecRefWithDestructor(d.destroy) + d.DecRefWithDestructor(ctx, d.destroy) msrc.DecDirentRefs() } else { - d.DecRefWithDestructor(d.destroy) + d.DecRefWithDestructor(ctx, d.destroy) } } @@ -1359,7 +1359,7 @@ func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error if err != nil { return err } - defer victim.DecRef() + defer victim.DecRef(ctx) return d.mayDelete(ctx, victim) } @@ -1411,7 +1411,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string if err != nil { return err } - defer renamed.DecRef() + defer renamed.DecRef(ctx) // Check that the renamed dirent is deletable. if err := oldParent.mayDelete(ctx, renamed); err != nil { @@ -1453,13 +1453,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // Check that we can delete replaced. if err := newParent.mayDelete(ctx, replaced); err != nil { - replaced.DecRef() + replaced.DecRef(ctx) return err } // Target should not be an ancestor of source. if oldParent.descendantOf(replaced) { - replaced.DecRef() + replaced.DecRef(ctx) // Note that Linux returns EINVAL if the source is an // ancestor of target, but ENOTEMPTY if the target is @@ -1470,7 +1470,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // Check that replaced is not a mount point. if replaced.isMountPointLocked() { - replaced.DecRef() + replaced.DecRef(ctx) return syscall.EBUSY } @@ -1478,11 +1478,11 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string oldIsDir := IsDir(renamed.Inode.StableAttr) newIsDir := IsDir(replaced.Inode.StableAttr) if !newIsDir && oldIsDir { - replaced.DecRef() + replaced.DecRef(ctx) return syscall.ENOTDIR } if !oldIsDir && newIsDir { - replaced.DecRef() + replaced.DecRef(ctx) return syscall.EISDIR } @@ -1493,13 +1493,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // open across renames is currently broken for multiple // reasons, so we flush all references on the replaced node and // its children. - replaced.Inode.Watches.Unpin(replaced) + replaced.Inode.Watches.Unpin(ctx, replaced) replaced.mu.Lock() - replaced.flush() + replaced.flush(ctx) replaced.mu.Unlock() // Done with replaced. - replaced.DecRef() + replaced.DecRef(ctx) } if err := renamed.Inode.Rename(ctx, oldParent, renamed, newParent, newName, replaced != nil); err != nil { @@ -1513,14 +1513,14 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // can't destroy oldParent (and try to retake its lock) because // Rename's caller must be holding a reference. newParent.IncRef() - oldParent.DecRef() + oldParent.DecRef(ctx) } if w, ok := newParent.children[newName]; ok { - w.Drop() + w.Drop(ctx) delete(newParent.children, newName) } if w, ok := oldParent.children[oldName]; ok { - w.Drop() + w.Drop(ctx) delete(oldParent.children, oldName) } @@ -1551,7 +1551,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // Same as replaced.flush above. renamed.mu.Lock() - renamed.flush() + renamed.flush(ctx) renamed.mu.Unlock() return nil diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 33de32c69..7d9dd717e 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -17,6 +17,7 @@ package fs import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -101,7 +102,7 @@ func (c *DirentCache) remove(d *Dirent) { panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d)) } c.list.Remove(d) - d.DecRef() + d.DecRef(context.Background()) c.currentSize-- if c.limit != nil { c.limit.dec() diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 98d69c6f2..176b894ba 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -51,7 +51,7 @@ func TestWalkPositive(t *testing.T) { t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1) } - d.DecRef() + d.DecRef(ctx) if got := root.ReadRefs(); got != 1 { t.Fatalf("root has a ref count of %d, want %d", got, 1) @@ -61,7 +61,7 @@ func TestWalkPositive(t *testing.T) { t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0) } - root.flush() + root.flush(ctx) if got := len(root.children); got != 0 { t.Fatalf("root has %d children, want %d", got, 0) @@ -114,7 +114,7 @@ func TestWalkNegative(t *testing.T) { t.Fatalf("child has a ref count of %d, want %d", got, 2) } - child.DecRef() + child.DecRef(ctx) if got := child.(*Dirent).ReadRefs(); got != 1 { t.Fatalf("child has a ref count of %d, want %d", got, 1) @@ -124,7 +124,7 @@ func TestWalkNegative(t *testing.T) { t.Fatalf("root has %d children, want %d", got, 1) } - root.DecRef() + root.DecRef(ctx) if got := root.ReadRefs(); got != 0 { t.Fatalf("root has a ref count of %d, want %d", got, 0) @@ -351,9 +351,9 @@ func TestRemoveExtraRefs(t *testing.T) { t.Fatalf("dirent has a ref count of %d, want %d", got, 1) } - d.DecRef() + d.DecRef(ctx) - test.root.flush() + test.root.flush(ctx) if got := len(test.root.children); got != 0 { t.Errorf("root has %d children, want %d", got, 0) @@ -403,8 +403,8 @@ func TestRenameExtraRefs(t *testing.T) { t.Fatalf("Rename got error %v, want nil", err) } - oldParent.flush() - newParent.flush() + oldParent.flush(ctx) + newParent.flush(ctx) // Expect to have only active references. if got := renamed.ReadRefs(); got != 1 { diff --git a/pkg/sentry/fs/dirent_state.go b/pkg/sentry/fs/dirent_state.go index f623d6c0e..67a35f0b2 100644 --- a/pkg/sentry/fs/dirent_state.go +++ b/pkg/sentry/fs/dirent_state.go @@ -18,6 +18,7 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" ) @@ -48,7 +49,7 @@ func (d *Dirent) saveChildren() map[string]*Dirent { for name, w := range d.children { if rc := w.Get(); rc != nil { // Drop the reference count obtain in w.Get() - rc.DecRef() + rc.DecRef(context.Background()) cd := rc.(*Dirent) if cd.IsNegative() { diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 9fce177ad..b99199798 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -115,7 +115,7 @@ func (p *pipeOperations) Readiness(mask waiter.EventMask) (eventMask waiter.Even } // Release implements fs.FileOperations.Release. -func (p *pipeOperations) Release() { +func (p *pipeOperations) Release(context.Context) { fdnotifier.RemoveFD(int32(p.file.FD())) p.file.Close() p.file = nil diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index e556da48a..b9cec4b13 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -182,7 +182,7 @@ func TestTryOpen(t *testing.T) { // Cleanup the state of the pipe, and remove the fd from the // fdnotifier. Sadly this needed to maintain the correctness // of other tests because the fdnotifier is global. - pipeOps.Release() + pipeOps.Release(ctx) } continue } @@ -191,7 +191,7 @@ func TestTryOpen(t *testing.T) { } if pipeOps != nil { // Same as above. - pipeOps.Release() + pipeOps.Release(ctx) } } } @@ -279,7 +279,7 @@ func TestPipeOpenUnblocksEventually(t *testing.T) { pipeOps, err := Open(ctx, opener, flags) if pipeOps != nil { // Same as TestTryOpen. - pipeOps.Release() + pipeOps.Release(ctx) } // Check that the partner opened the file successfully. @@ -325,7 +325,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) { ctx := contexttest.Context(t) pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) if pipeOps != nil { - pipeOps.Release() + pipeOps.Release(ctx) t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY) } if err != syserror.ErrWouldBlock { @@ -351,7 +351,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) { if pipeOps == nil { t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY) } - defer pipeOps.Release() + defer pipeOps.Release(ctx) if err != nil { t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err) @@ -471,14 +471,14 @@ func TestPipeHangup(t *testing.T) { f := <-fdchan if f < 0 { t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f) - pipeOps.Release() + pipeOps.Release(ctx) continue } if test.hangupSelf { // Hangup self and assert that our partner got the expected hangup // error. - pipeOps.Release() + pipeOps.Release(ctx) if test.flags.Read { // Partner is writer. @@ -490,7 +490,7 @@ func TestPipeHangup(t *testing.T) { } else { // Hangup our partner and expect us to get the hangup error. syscall.Close(f) - defer pipeOps.Release() + defer pipeOps.Release(ctx) if test.flags.Read { assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file) diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index a0082ecca..1c9e82562 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -98,10 +98,11 @@ func TestNewPipe(t *testing.T) { } f := fd.New(gfd) - p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer) + ctx := contexttest.Context(t) + p, err := newPipeOperations(ctx, nil, test.flags, f, test.readAheadBuffer) if p != nil { // This is necessary to remove the fd from the global fd notifier. - defer p.Release() + defer p.Release(ctx) } else { // If there is no p to DecRef on, because newPipeOperations failed, then the // file still needs to be closed. @@ -153,13 +154,14 @@ func TestPipeDestruction(t *testing.T) { syscall.Close(fds[1]) // Test the read end, but it doesn't really matter which. - p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil) + ctx := contexttest.Context(t) + p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, f, nil) if err != nil { f.Close() t.Fatalf("newPipeOperations got error %v, want nil", err) } // Drop our only reference, which should trigger the destructor. - p.Release() + p.Release(ctx) if fdnotifier.HasFD(int32(fds[0])) { t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0]) @@ -282,7 +284,7 @@ func TestPipeRequest(t *testing.T) { if err != nil { t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err) } - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) @@ -334,7 +336,7 @@ func TestPipeReadAheadBuffer(t *testing.T) { rfile.Close() t.Fatalf("newPipeOperations got error %v, want nil", err) } - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, @@ -380,7 +382,7 @@ func TestPipeReadsAccumulate(t *testing.T) { } // Don't forget to remove the fd from the fd notifier. Otherwise other tests will // likely be borked, because it's global :( - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, @@ -448,7 +450,7 @@ func TestPipeWritesAccumulate(t *testing.T) { } // Don't forget to remove the fd from the fd notifier. Otherwise other tests // will likely be borked, because it's global :( - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index ca41520b4..72ea70fcf 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -142,17 +142,17 @@ func NewFile(ctx context.Context, dirent *Dirent, flags FileFlags, fops FileOper } // DecRef destroys the File when it is no longer referenced. -func (f *File) DecRef() { - f.DecRefWithDestructor(func() { +func (f *File) DecRef(ctx context.Context) { + f.DecRefWithDestructor(ctx, func(context.Context) { // Drop BSD style locks. lockRng := lock.LockRange{Start: 0, End: lock.LockEOF} f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng) // Release resources held by the FileOperations. - f.FileOperations.Release() + f.FileOperations.Release(ctx) // Release a reference on the Dirent. - f.Dirent.DecRef() + f.Dirent.DecRef(ctx) // Only unregister if we are currently registered. There is nothing // to register if f.async is nil (this happens when async mode is @@ -460,7 +460,7 @@ func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) { func (f *File) MappedName(ctx context.Context) string { root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } name, _ := f.Dirent.FullName(root) return name diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index beba0f771..6ec721022 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -67,7 +67,7 @@ type SpliceOpts struct { // - File.Flags(): This value may change during the operation. type FileOperations interface { // Release release resources held by FileOperations. - Release() + Release(ctx context.Context) // Waitable defines how this File can be waited on for read and // write readiness. @@ -159,7 +159,9 @@ type FileOperations interface { // io provides access to the virtual memory space to which pointers in args // refer. // - // Preconditions: The AddressSpace (if any) that io refers to is activated. + // Preconditions: + // * The AddressSpace (if any) that io refers to is activated. + // * Must only be called from a task goroutine. Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) } diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index dcc1df38f..9dc58d5ff 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -54,7 +54,7 @@ func overlayFile(ctx context.Context, inode *Inode, flags FileFlags) (*File, err // Drop the extra reference on the Dirent. Now there's only one reference // on the dirent, either owned by f (if non-nil), or the Dirent is about // to be destroyed (if GetFile failed). - dirent.DecRef() + dirent.DecRef(ctx) return f, err } @@ -89,12 +89,12 @@ type overlayFileOperations struct { } // Release implements FileOperations.Release. -func (f *overlayFileOperations) Release() { +func (f *overlayFileOperations) Release(ctx context.Context) { if f.upper != nil { - f.upper.DecRef() + f.upper.DecRef(ctx) } if f.lower != nil { - f.lower.DecRef() + f.lower.DecRef(ctx) } } @@ -164,7 +164,7 @@ func (f *overlayFileOperations) Seek(ctx context.Context, file *File, whence See func (f *overlayFileOperations) Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error) { root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &DirCtx{ @@ -497,7 +497,7 @@ func readdirOne(ctx context.Context, d *Dirent) (map[string]DentAttr, error) { if err != nil { return nil, err } - defer dir.DecRef() + defer dir.DecRef(ctx) // Use a stub serializer to read the entries into memory. stubSerializer := &CollectEntriesSerializer{} @@ -521,10 +521,10 @@ type overlayMappingIdentity struct { } // DecRef implements AtomicRefCount.DecRef. -func (omi *overlayMappingIdentity) DecRef() { - omi.AtomicRefCount.DecRefWithDestructor(func() { - omi.overlayFile.DecRef() - omi.id.DecRef() +func (omi *overlayMappingIdentity) DecRef(ctx context.Context) { + omi.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) { + omi.overlayFile.DecRef(ctx) + omi.id.DecRef(ctx) }) } @@ -544,7 +544,7 @@ func (omi *overlayMappingIdentity) InodeID() uint64 { func (omi *overlayMappingIdentity) MappedName(ctx context.Context) string { root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } name, _ := omi.overlayFile.Dirent.FullName(root) return name diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 789369220..5fb419bcd 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -8,7 +8,6 @@ go_template_instance( out = "dirty_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "Dirty", @@ -25,14 +24,14 @@ go_template_instance( name = "frame_ref_set_impl", out = "frame_ref_set_impl.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "fsutil", prefix = "FrameRef", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "uint64", "Functions": "FrameRefSetFunctions", }, @@ -43,7 +42,6 @@ go_template_instance( out = "file_range_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "FileRange", @@ -86,7 +84,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/state", diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index c6cd45087..2c9446c1d 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) { // repeatedly until all bytes have been written. max is the true size of the // cached object; offsets beyond max will not be passed to writeAt, even if // they are marked dirty. -func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { var changedDirty bool defer func() { if changedDirty { @@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet // successful partial write, SyncDirtyAll will call it repeatedly until all // bytes have been written. max is the true size of the cached object; offsets // beyond max will not be passed to writeAt, even if they are marked dirty. -func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { dseg := dirty.FirstSegment() for dseg.Ok() { if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil { @@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max } // Preconditions: mr must be page-aligned. -func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() { wbr := cseg.Range().Intersect(mr) if max < wbr.Start { diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 08695391c..dc9efa5df 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -31,7 +31,7 @@ import ( type FileNoopRelease struct{} // Release is a no-op. -func (FileNoopRelease) Release() {} +func (FileNoopRelease) Release(context.Context) {} // SeekWithDirCursor is used to implement fs.FileOperations.Seek. If dirCursor // is not nil and the seek was on a directory, the cursor will be updated. @@ -296,7 +296,7 @@ func (sdfo *StaticDirFileOperations) IterateDir(ctx context.Context, d *fs.Diren func (sdfo *StaticDirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 5643cdac9..9197aeb88 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -23,13 +23,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a -// platform.File. It is used to implement Mappables that store data in +// memmap.File. It is used to implement Mappables that store data in // sparsely-allocated memory. // // type FileRangeSet <generated by go_generics> @@ -65,20 +64,22 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli } // FileRange returns the FileRange mapped by seg. -func (seg FileRangeIterator) FileRange() platform.FileRange { +func (seg FileRangeIterator) FileRange() memmap.FileRange { return seg.FileRangeOf(seg.Range()) } // FileRangeOf returns the FileRange mapped by mr. // -// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0. -func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange { +// Preconditions: +// * seg.Range().IsSupersetOf(mr). +// * mr.Length() != 0. +func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange { frstart := seg.Value() + (mr.Start - seg.Start()) - return platform.FileRange{frstart, frstart + mr.Length()} + return memmap.FileRange{frstart, frstart + mr.Length()} } // Fill attempts to ensure that all memmap.Mappable offsets in required are -// mapped to a platform.File offset, by allocating from mf with the given +// mapped to a memmap.File offset, by allocating from mf with the given // memory usage kind and invoking readAt to store data into memory. (If readAt // returns a successful partial read, Fill will call it repeatedly until all // bytes have been read.) EOF is handled consistently with the requirements of @@ -89,8 +90,10 @@ func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileR // outside of optional. It returns a non-nil error if any error occurs, even // if the error only affects offsets in optional, but not in required. // -// Preconditions: required.Length() > 0. optional.IsSupersetOf(required). -// required and optional must be page-aligned. +// Preconditions: +// * required.Length() > 0. +// * optional.IsSupersetOf(required). +// * required and optional must be page-aligned. func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error { gap := frs.LowerBoundGap(required.Start) for gap.Ok() && gap.Start() < required.End { @@ -141,7 +144,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map } // Drop removes segments for memmap.Mappable offsets in mr, freeing the -// corresponding platform.FileRanges. +// corresponding memmap.FileRanges. // // Preconditions: mr must be page-aligned. func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { @@ -154,7 +157,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { } // DropAll removes all segments in mr, freeing the corresponding -// platform.FileRanges. +// memmap.FileRanges. func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) { for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { mf.DecRef(seg.FileRange()) diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index dd6f5aba6..a808894df 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -17,7 +17,7 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" ) @@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) { } // Merge implements segment.Functions.Merge. -func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) { +func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) { if val1 != val2 { return 0, false } @@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. } // Split implements segment.Functions.Split. -func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { +func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } // IncRefAndAccount adds a reference on the range fr. All newly inserted segments // are accounted as host page cache memory mappings. -func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) { seg, gap := refs.Find(fr.Start) for { switch { @@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { // DecRefAndAccount removes a reference on the range fr and untracks segments // that are removed from memory accounting. -func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) { seg := refs.FindSegment(fr.Start) for seg.Ok() && seg.Start() < fr.End { diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e82afd112..1390a9a7f 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -81,7 +80,9 @@ func NewHostFileMapper() *HostFileMapper { // IncRefOn increments the reference count on all offsets in mr. // -// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned. +// Preconditions: +// * mr.Length() != 0. +// * mr.Start and mr.End must be page-aligned. func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() @@ -98,7 +99,9 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { // DecRefOn decrements the reference count on all offsets in mr. // -// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned. +// Preconditions: +// * mr.Length() != 0. +// * mr.Start and mr.End must be page-aligned. func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() @@ -126,7 +129,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { // offsets in fr or until the next call to UnmapAll. // // Preconditions: The caller must hold a reference on all offsets in fr. -func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) { +func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) { chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift) f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -146,7 +149,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) } // Preconditions: f.mapsMu must be locked. -func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error { +func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error { prot := syscall.PROT_READ if write { prot |= syscall.PROT_WRITE @@ -205,7 +208,9 @@ func (f *HostFileMapper) UnmapAll() { } } -// Preconditions: f.mapsMu must be locked. f.mappings[chunkStart] == m. +// Preconditions: +// * f.mapsMu must be locked. +// * f.mappings[chunkStart] == m. func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) { if _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m.addr, chunkSize, 0); errno != 0 { // This leaks address space and is unexpected, but is otherwise diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 78fec553e..c15d8a946 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -21,18 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// HostMappable implements memmap.Mappable and platform.File over a +// HostMappable implements memmap.Mappable and memmap.File over a // CachedFileObject. // // Lock order (compare the lock order model in mm/mm.go): // truncateMu ("fs locks") // mu ("memmap.Mappable locks not taken by Translate") -// ("platform.File locks") +// ("memmap.File locks") // backingFile ("CachedFileObject locks") // // +stateify savable @@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error { return nil } -// MapInternal implements platform.File.MapInternal. -func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (h *HostMappable) FD() int { return h.backingFile.FD() } -// IncRef implements platform.File.IncRef. -func (h *HostMappable) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (h *HostMappable) IncRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.IncRefOn(mr) } -// DecRef implements platform.File.DecRef. -func (h *HostMappable) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (h *HostMappable) DecRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.DecRefOn(mr) } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 800c8b4e1..9eb6f522e 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -26,7 +26,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -685,7 +684,9 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // maybeGrowFile grows the file's size if data has been written past the old // size. // -// Preconditions: rw.c.attrMu and rw.c.dataMu bust be locked. +// Preconditions: +// * rw.c.attrMu must be locked. +// * rw.c.dataMu must be locked. func (rw *inodeReadWriter) maybeGrowFile() { // If the write ends beyond the file's previous size, it causes the // file to grow. @@ -934,7 +935,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. c.mapsMu.Lock() @@ -999,10 +1000,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable } } -// IncRef implements platform.File.IncRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// IncRef implements memmap.File.IncRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { +func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg, gap := c.refs.Find(fr.Start) @@ -1024,10 +1025,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { } } -// DecRef implements platform.File.DecRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// DecRef implements memmap.File.DecRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { +func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg := c.refs.FindSegment(fr.Start) @@ -1046,15 +1047,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { c.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. This is used when we +// MapInternal implements memmap.File.MapInternal. This is used when we // directly map an underlying host fd and CachingInodeOperations is used as the -// platform.File during translation. -func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// memmap.File during translation. +func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// FD implements memmap.File.FD. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. func (c *CachingInodeOperations) FD() int { return c.backingFile.FD() diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md index 2ca84dd74..05e043583 100644 --- a/pkg/sentry/fs/g3doc/fuse.md +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -79,7 +79,7 @@ ops can be implemented in parallel. - Implement `/dev/fuse` - a character device used to establish an FD for communication between the sentry and the server daemon. -- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`. +- Implement basic FUSE ops like `FUSE_INIT`. #### Read-only mount with basic file operations @@ -95,6 +95,103 @@ ops can be implemented in parallel. - Implement the remaining FUSE ops and decide if we can omit rarely used operations like ioctl. +### Design Details + +#### Lifecycle for a FUSE Request + +- User invokes a syscall +- Sentry prepares corresponding request + - If FUSE device is available + - Write the request in binary + - If FUSE device is full + - Kernel task blocked until available +- Sentry notifies the readers of fuse device that it's ready for read +- FUSE daemon reads the request and processes it +- Sentry waits until a reply is written to the FUSE device + - but returns directly for async requests +- FUSE daemon writes to the fuse device +- Sentry processes the reply + - For sync requests, unblock blocked kernel task + - For async requests, execute pre-specified callback if any +- Sentry returns the syscall to the user + +#### Channels and Queues for Requests in Different Stages + +`connection.initializedChan` + +- a channel that the requests issued before connection initialization blocks + on. + +`fd.queue` + +- a queue of requests that haven’t been read by the FUSE daemon yet. + +`fd.completions` + +- a map of the requests that have been prepared but not yet received a + response, including the ones on the `fd.queue`. + +`fd.waitQueue` + +- a queue of waiters that is waiting for the fuse device fd to be available, + such as the FUSE daemon. + +`fd.fullQueueCh` + +- a channel that the kernel task will be blocked on when the fd is not + available. + +#### Basic I/O Implementation + +Currently we have implemented basic functionalities of read and write for our +FUSE. We describe the design and ways to improve it here: + +##### Basic FUSE Read + +The vfs2 expects implementations of `vfs.FileDescriptionImpl.Read()` and +`vfs.FileDescriptionImpl.PRead()`. When a syscall is made, it will eventually +reach our implementation of those interface functions located at +`pkg/sentry/fsimpl/fuse/regular_file.go` for regular files. + +After validation checks of the input, sentry sends `FUSE_READ` requests to the +FUSE daemon. The FUSE daemon returns data after the `fuse_out_header` as the +responses. For the first version, we create a copy in kernel memory of those +data. They are represented as a byte slice in the marshalled struct. This +happens as a common process for all the FUSE responses at this moment at +`pkg/sentry/fsimpl/fuse/dev.go:writeLocked()`. We then directly copy from this +intermediate buffer to the input buffer provided by the read syscall. + +There is an extra requirement for FUSE: When mounting the FUSE fs, the mounter +or the FUSE daemon can specify a `max_read` or a `max_pages` parameter. They are +the upperbound of the bytes to read in each `FUSE_READ` request. We implemented +the code to handle the fragmented reads. + +To improve the performance: ideally we should have buffer cache to copy those +data from the responses of FUSE daemon into, as is also the design of several +other existing file system implementations for sentry, instead of a single-use +temporary buffer. Directly mapping the memory of one process to another could +also boost the performance, but to keep them isolated, we did not choose to do +so. + +##### Basic FUSE Write + +The vfs2 invokes implementations of `vfs.FileDescriptionImpl.Write()` and +`vfs.FileDescriptionImpl.PWrite()` on the regular file descriptor of FUSE when a +user makes write(2) and pwrite(2) syscall. + +For valid writes, sentry sends the bytes to write after a `FUSE_WRITE` header +(can be regarded as a request with 2 payloads) to the FUSE daemon. For the first +version, we allocate a buffer inside kernel memory to store the bytes from the +user, and copy directly from that buffer to the memory of FUSE daemon. This +happens at `pkg/sentry/fsimpl/fuse/dev.go:readLocked()` + +The parameters `max_write` and `max_pages` restrict the number of bytes in one +`FUSE_WRITE`. There are code handling fragmented writes in current +implementation. + +To have better performance: the extra copy created to store the bytes to write +can be replaced by the buffer cache as well. + # Appendix ## FUSE Protocol diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index b2fcab127..c0bc63a32 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -114,7 +114,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } // Release implements fs.FileOpeations.Release. -func (f *fileOperations) Release() { +func (f *fileOperations) Release(context.Context) { f.handles.DecRef() } @@ -122,7 +122,7 @@ func (f *fileOperations) Release() { func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go index 2df2fe889..326fed954 100644 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ b/pkg/sentry/fs/gofer/gofer_test.go @@ -232,7 +232,7 @@ func TestRevalidation(t *testing.T) { // We must release the dirent, of the test will fail // with a reference leak. This is tracked by p9test. - defer dirent.DecRef() + defer dirent.DecRef(ctx) // Walk again. Depending on the cache policy, we may // get a new dirent. @@ -246,7 +246,7 @@ func TestRevalidation(t *testing.T) { if !test.preModificationWantReload && dirent != newDirent { t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent) } - newDirent.DecRef() // See above. + newDirent.DecRef(ctx) // See above. // Modify the underlying mocked file's modification // time for the next walk that occurs. @@ -287,7 +287,7 @@ func TestRevalidation(t *testing.T) { if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds { t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds) } - newDirent.DecRef() // See above. + newDirent.DecRef(ctx) // See above. // Remove the file from the remote fs, subsequent walks // should now fail to find anything. @@ -303,7 +303,7 @@ func TestRevalidation(t *testing.T) { t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err) } if err == nil { - newDirent.DecRef() // See above. + newDirent.DecRef(ctx) // See above. } }) } diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index fc14249be..f324dbf26 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -47,7 +47,8 @@ type handles struct { // DecRef drops a reference on handles. func (h *handles) DecRef() { - h.DecRefWithDestructor(func() { + ctx := context.Background() + h.DecRefWithDestructor(ctx, func(context.Context) { if h.Host != nil { if h.isHostBorrowed { h.Host.Release() @@ -57,7 +58,7 @@ func (h *handles) DecRef() { } } } - if err := h.File.close(context.Background()); err != nil { + if err := h.File.close(ctx); err != nil { log.Warningf("error closing p9 file: %v", err) } }) diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 51d7368a1..3a225fd39 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -441,8 +441,9 @@ func (i *inodeOperations) Release(ctx context.Context) { // asynchronously. // // We use AsyncWithContext to avoid needing to allocate an extra - // anonymous function on the heap. - fs.AsyncWithContext(ctx, i.fileState.Release) + // anonymous function on the heap. We must use background context + // because the async work cannot happen on the task context. + fs.AsyncWithContext(context.Background(), i.fileState.Release) } // Mappable implements fs.InodeOperations.Mappable. diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index cf9800100..3c66dc3c2 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -168,7 +168,7 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string // Construct the positive Dirent. d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name) - defer d.DecRef() + defer d.DecRef(ctx) // Construct the new file, caching the handles if allowed. h := handles{ @@ -371,7 +371,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string // Find out if file being deleted is a socket or pipe that needs to be // removed from endpoint map. if d, err := i.Lookup(ctx, dir, name); err == nil { - defer d.DecRef() + defer d.DecRef(ctx) if fs.IsSocket(d.Inode.StableAttr) || fs.IsPipe(d.Inode.StableAttr) { switch iops := d.Inode.InodeOperations.(type) { @@ -392,7 +392,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string return err } if key != nil { - i.session().overrides.remove(*key) + i.session().overrides.remove(ctx, *key) } i.touchModificationAndStatusChangeTime(ctx, dir) diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index b5efc86f2..7cf3522ff 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -89,10 +89,10 @@ func (e *overrideMaps) addPipe(key device.MultiDeviceKey, d *fs.Dirent, inode *f // remove deletes the key from the maps. // // Precondition: maps must have been locked with 'lock'. -func (e *overrideMaps) remove(key device.MultiDeviceKey) { +func (e *overrideMaps) remove(ctx context.Context, key device.MultiDeviceKey) { endpoint := e.keyMap[key] delete(e.keyMap, key) - endpoint.dirent.DecRef() + endpoint.dirent.DecRef(ctx) } // lock blocks other addition and removal operations from happening while @@ -197,7 +197,7 @@ type session struct { } // Destroy tears down the session. -func (s *session) Destroy() { +func (s *session) Destroy(ctx context.Context) { s.client.Close() } @@ -329,7 +329,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF s.client, err = p9.NewClient(conn, s.msize, s.version) if err != nil { // Drop our reference on the session, it needs to be torn down. - s.DecRef() + s.DecRef(ctx) return nil, err } @@ -340,7 +340,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF ctx.UninterruptibleSleepFinish(false) if err != nil { // Same as above. - s.DecRef() + s.DecRef(ctx) return nil, err } @@ -348,7 +348,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF if err != nil { s.attach.close(ctx) // Same as above, but after we execute the Close request. - s.DecRef() + s.DecRef(ctx) return nil, err } @@ -393,13 +393,13 @@ func (s *session) fillKeyMap(ctx context.Context) error { // fillPathMap populates paths for overrides from dirents in direntMap // before save. -func (s *session) fillPathMap() error { +func (s *session) fillPathMap(ctx context.Context) error { unlock := s.overrides.lock() defer unlock() for _, endpoint := range s.overrides.keyMap { mountRoot := endpoint.dirent.MountRoot() - defer mountRoot.DecRef() + defer mountRoot.DecRef(ctx) dirPath, _ := endpoint.dirent.FullName(mountRoot) if dirPath == "" { return fmt.Errorf("error getting path from dirent") diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go index 2d398b753..48b423dd8 100644 --- a/pkg/sentry/fs/gofer/session_state.go +++ b/pkg/sentry/fs/gofer/session_state.go @@ -26,7 +26,8 @@ import ( // beforeSave is invoked by stateify. func (s *session) beforeSave() { if s.overrides != nil { - if err := s.fillPathMap(); err != nil { + ctx := &dummyClockContext{context.Background()} + if err := s.fillPathMap(ctx); err != nil { panic("failed to save paths to override map before saving" + err.Error()) } } diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 40f2c1cad..8a1c69ac2 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -134,14 +134,14 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect // We don't need the receiver. c.CloseRecv() - c.Release() + c.Release(ctx) return c, nil } // Release implements transport.BoundEndpoint.Release. -func (e *endpoint) Release() { - e.inode.DecRef() +func (e *endpoint) Release(ctx context.Context) { + e.inode.DecRef(ctx) } // Passcred implements transport.BoundEndpoint.Passcred. diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index aabce6cc9..1368014c4 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -30,7 +30,9 @@ go_library( "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", + "//pkg/iovec", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/secio", diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go index 39299b7e4..0d8d36afa 100644 --- a/pkg/sentry/fs/host/control.go +++ b/pkg/sentry/fs/host/control.go @@ -57,7 +57,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (c *scmRights) Release() { +func (c *scmRights) Release(ctx context.Context) { for _, fd := range c.fds { syscall.Close(fd) } diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go index 3e48b8b2c..86d1a87f0 100644 --- a/pkg/sentry/fs/host/file.go +++ b/pkg/sentry/fs/host/file.go @@ -110,7 +110,7 @@ func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool name := fmt.Sprintf("host:[%d]", inode.StableAttr.InodeID) dirent := fs.NewDirent(ctx, inode, name) - defer dirent.DecRef() + defer dirent.DecRef(ctx) if isTTY { return newTTYFile(ctx, dirent, flags, iops), nil @@ -169,7 +169,7 @@ func (f *fileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go index c507f57eb..41a23b5da 100644 --- a/pkg/sentry/fs/host/inode_test.go +++ b/pkg/sentry/fs/host/inode_test.go @@ -36,7 +36,7 @@ func TestCloseFD(t *testing.T) { if err != nil { t.Fatalf("Failed to create File: %v", err) } - file.DecRef() + file.DecRef(ctx) s := make([]byte, 10) if c, err := syscall.Read(p[0], s); c != 0 || err != nil { diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index cfb089e43..a2f3d5918 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -194,7 +194,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) } // Send implements transport.ConnectedEndpoint.Send. -func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -271,7 +271,7 @@ func (c *ConnectedEndpoint) EventUpdate() { } // Recv implements transport.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -318,7 +318,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek } // close releases all resources related to the endpoint. -func (c *ConnectedEndpoint) close() { +func (c *ConnectedEndpoint) close(context.Context) { fdnotifier.RemoveFD(int32(c.file.FD())) c.file.Close() c.file = nil @@ -374,8 +374,8 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { } // Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release. -func (c *ConnectedEndpoint) Release() { - c.ref.DecRefWithDestructor(c.close) +func (c *ConnectedEndpoint) Release(ctx context.Context) { + c.ref.DecRefWithDestructor(ctx, c.close) } // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index 5c18dbd5e..905afb50d 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -17,15 +17,12 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -76,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index affdbcacb..9d58ea448 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -67,11 +67,12 @@ func TestSocketIsBlocking(t *testing.T) { if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { t.Fatalf("Expected socket %v to be blocking", pair[1]) } - sock, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + sock, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) failed => %v", pair[0], err) } - defer sock.DecRef() + defer sock.DecRef(ctx) // Test that the socket now is non-blocking. if fl, err = getFl(pair[0]); err != nil { t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) @@ -93,11 +94,12 @@ func TestSocketWritev(t *testing.T) { if err != nil { t.Fatalf("host socket creation failed: %v", err) } - socket, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + socket, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer socket.DecRef() + defer socket.DecRef(ctx) buf := []byte("hello world\n") n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf)) if err != nil { @@ -115,11 +117,12 @@ func TestSocketWritevLen0(t *testing.T) { if err != nil { t.Fatalf("host socket creation failed: %v", err) } - socket, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + socket, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer socket.DecRef() + defer socket.DecRef(ctx) n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil)) if err != nil { t.Fatalf("socket writev failed: %v", err) @@ -136,11 +139,12 @@ func TestSocketSendMsgLen0(t *testing.T) { if err != nil { t.Fatalf("host socket creation failed: %v", err) } - sfile, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + sfile, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer sfile.DecRef() + defer sfile.DecRef(ctx) s := sfile.FileOperations.(socket.Socket) n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) @@ -158,18 +162,19 @@ func TestListen(t *testing.T) { if err != nil { t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) } - sfile1, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + sfile1, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer sfile1.DecRef() + defer sfile1.DecRef(ctx) socket1 := sfile1.FileOperations.(socket.Socket) - sfile2, err := newSocket(contexttest.Context(t), pair[1], false) + sfile2, err := newSocket(ctx, pair[1], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[1], err) } - defer sfile2.DecRef() + defer sfile2.DecRef(ctx) socket2 := sfile2.FileOperations.(socket.Socket) // Socketpairs can not be listened to. @@ -185,11 +190,11 @@ func TestListen(t *testing.T) { if err != nil { t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) } - sfile3, err := newSocket(contexttest.Context(t), sock, false) + sfile3, err := newSocket(ctx, sock, false) if err != nil { t.Fatalf("newSocket(%v) => %v", sock, err) } - defer sfile3.DecRef() + defer sfile3.DecRef(ctx) socket3 := sfile3.FileOperations.(socket.Socket) // This socket is not bound so we can't listen on it. @@ -237,9 +242,10 @@ func TestRelease(t *testing.T) { } c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} want := &ConnectedEndpoint{queue: c.queue} - want.ref.DecRef() + ctx := contexttest.Context(t) + want.ref.DecRef(ctx) fdnotifier.AddFD(int32(c.file.FD()), nil) - c.Release() + c.Release(ctx) if !reflect.DeepEqual(c, want) { t.Errorf("got = %#v, want = %#v", c, want) } diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index 5d4f312cf..c8231e0aa 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -65,10 +65,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) ( controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC if n > length { - return length, n, msg.Controllen, controlTrunc, err + return length, n, msg.Controllen, controlTrunc, nil } - return n, n, msg.Controllen, controlTrunc, err + return n, n, msg.Controllen, controlTrunc, nil } // fdWriteVec sends from bufs to fd. diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 82a02fcb2..1183727ab 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -53,7 +54,7 @@ type TTYFileOperations struct { func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File { return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{ fileOperations: fileOperations{iops: iops}, - termios: linux.DefaultSlaveTermios, + termios: linux.DefaultReplicaTermios, }) } @@ -113,16 +114,21 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme } // Release implements fs.FileOperations.Release. -func (t *TTYFileOperations) Release() { +func (t *TTYFileOperations) Release(ctx context.Context) { t.mu.Lock() t.fgProcessGroup = nil t.mu.Unlock() - t.fileOperations.Release() + t.fileOperations.Release(ctx) } // Ioctl implements fs.FileOperations.Ioctl. func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + return 0, syserror.ENOTTY + } + // Ignore arg[0]. This is the real FD: fd := t.fileOperations.iops.fileState.FD() ioctl := args[1].Uint64() @@ -132,9 +138,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = termios.CopyOut(task, args[2].Pointer()) return 0, err case linux.TCSETS, linux.TCSETSW, linux.TCSETSF: @@ -146,9 +150,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO } var termios linux.Termios - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := termios.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetTermios(fd, ioctl, &termios) @@ -173,10 +175,8 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // Map the ProcessGroup into a ProcessGroupID in the task's PID // namespace. - pgID := pidns.IDOfProcessGroup(t.fgProcessGroup) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }) + pgID := primitive.Int32(pidns.IDOfProcessGroup(t.fgProcessGroup)) + _, err := pgID.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSPGRP: @@ -184,11 +184,6 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // Equivalent to tcsetpgrp(fd, *argp). // Set the foreground process group ID of this terminal. - task := kernel.TaskFromContext(ctx) - if task == nil { - return 0, syserror.ENOTTY - } - t.mu.Lock() defer t.mu.Unlock() @@ -208,12 +203,11 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO return 0, syserror.ENOTTY } - var pgID kernel.ProcessGroupID - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgIDP primitive.Int32 + if _, err := pgIDP.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } + pgID := kernel.ProcessGroupID(pgIDP) // pgID must be non-negative. if pgID < 0 { @@ -242,9 +236,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = winsize.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSWINSZ: @@ -255,9 +247,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // background ones) can set the winsize. var winsize linux.Winsize - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := winsize.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetWinsize(fd, &winsize) @@ -358,7 +348,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e // // Linux ignores the result of kill_pgrp(). _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) - return kernel.ERESTARTSYS + return syserror.ERESTARTSYS } // LINT.ThenChange(../../fsimpl/host/tty.go) diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go index ce397a5e3..c143f4ce2 100644 --- a/pkg/sentry/fs/host/wait_test.go +++ b/pkg/sentry/fs/host/wait_test.go @@ -39,7 +39,7 @@ func TestWait(t *testing.T) { t.Fatalf("NewFile failed: %v", err) } - defer file.DecRef() + defer file.DecRef(ctx) r := file.Readiness(waiter.EventIn) if r != 0 { diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index a34fbc946..004910453 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -96,13 +96,12 @@ func NewInode(ctx context.Context, iops InodeOperations, msrc *MountSource, satt } // DecRef drops a reference on the Inode. -func (i *Inode) DecRef() { - i.DecRefWithDestructor(i.destroy) +func (i *Inode) DecRef(ctx context.Context) { + i.DecRefWithDestructor(ctx, i.destroy) } // destroy releases the Inode and releases the msrc reference taken. -func (i *Inode) destroy() { - ctx := context.Background() +func (i *Inode) destroy(ctx context.Context) { if err := i.WriteOut(ctx); err != nil { // FIXME(b/65209558): Mark as warning again once noatime is // properly supported. @@ -122,12 +121,12 @@ func (i *Inode) destroy() { i.Watches.targetDestroyed() if i.overlay != nil { - i.overlay.release() + i.overlay.release(ctx) } else { i.InodeOperations.Release(ctx) } - i.MountSource.DecRef() + i.MountSource.DecRef(ctx) } // Mappable calls i.InodeOperations.Mappable. @@ -271,7 +270,7 @@ func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, // SetXattr calls i.InodeOperations.SetXattr with i as the Inode. func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error { if i.overlay != nil { - return overlaySetxattr(ctx, i.overlay, d, name, value, flags) + return overlaySetXattr(ctx, i.overlay, d, name, value, flags) } return i.InodeOperations.SetXattr(ctx, i, name, value, flags) } diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go index efd3c962b..9911a00c2 100644 --- a/pkg/sentry/fs/inode_inotify.go +++ b/pkg/sentry/fs/inode_inotify.go @@ -17,6 +17,7 @@ package fs import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -136,11 +137,11 @@ func (w *Watches) Notify(name string, events, cookie uint32) { } // Unpin unpins dirent from all watches in this set. -func (w *Watches) Unpin(d *Dirent) { +func (w *Watches) Unpin(ctx context.Context, d *Dirent) { w.mu.RLock() defer w.mu.RUnlock() for _, watch := range w.ws { - watch.Unpin(d) + watch.Unpin(ctx, d) } } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 537c8d257..b16ab08ba 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -16,7 +16,6 @@ package fs import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -85,7 +84,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name upperInode = child.Inode upperInode.IncRef() } - child.DecRef() + child.DecRef(ctx) } // Are we done? @@ -108,7 +107,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name entry, err := newOverlayEntry(ctx, upperInode, nil, false) if err != nil { // Don't leak resources. - upperInode.DecRef() + upperInode.DecRef(ctx) parent.copyMu.RUnlock() return nil, false, err } @@ -129,7 +128,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name if err != nil && err != syserror.ENOENT { // Don't leak resources. if upperInode != nil { - upperInode.DecRef() + upperInode.DecRef(ctx) } parent.copyMu.RUnlock() return nil, false, err @@ -152,7 +151,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } } } - child.DecRef() + child.DecRef(ctx) } } @@ -183,7 +182,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // unnecessary because we don't need to copy-up and we will always // operate (e.g. read/write) on the upper Inode. if !IsDir(upperInode.StableAttr) { - lowerInode.DecRef() + lowerInode.DecRef(ctx) lowerInode = nil } } @@ -194,10 +193,10 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // Well, not quite, we failed at the last moment, how depressing. // Be sure not to leak resources. if upperInode != nil { - upperInode.DecRef() + upperInode.DecRef(ctx) } if lowerInode != nil { - lowerInode.DecRef() + lowerInode.DecRef(ctx) } parent.copyMu.RUnlock() return nil, false, err @@ -248,7 +247,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st // user) will clobber the real path for the underlying Inode. upperFile.Dirent.Inode.IncRef() upperDirent := NewTransientDirent(upperFile.Dirent.Inode) - upperFile.Dirent.DecRef() + upperFile.Dirent.DecRef(ctx) upperFile.Dirent = upperDirent // Create the overlay inode and dirent. We need this to construct the @@ -259,7 +258,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st // The overlay file created below with NewFile will take a reference on // the overlayDirent, and it should be the only thing holding a // reference at the time of creation, so we must drop this reference. - defer overlayDirent.DecRef() + defer overlayDirent.DecRef(ctx) // Create a new overlay file that wraps the upper file. flags.Pread = upperFile.Flags().Pread @@ -399,7 +398,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena if !replaced.IsNegative() && IsDir(replaced.Inode.StableAttr) { children, err := readdirOne(ctx, replaced) if err != nil { - replaced.DecRef() + replaced.DecRef(ctx) return err } @@ -407,12 +406,12 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena // included among the returned children, so we don't // need to bother checking for them. if len(children) > 0 { - replaced.DecRef() + replaced.DecRef(ctx) return syserror.ENOTEMPTY } } - replaced.DecRef() + replaced.DecRef(ctx) } } @@ -455,12 +454,12 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri // Grab the inode and drop the dirent, we don't need it. inode := d.Inode inode.IncRef() - d.DecRef() + d.DecRef(ctx) // Create a new overlay entry and dirent for the socket. entry, err := newOverlayEntry(ctx, inode, nil, false) if err != nil { - inode.DecRef() + inode.DecRef(ctx) return nil, err } // Use the parent's MountSource, since that corresponds to the overlay, @@ -539,7 +538,7 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin // Don't forward the value of the extended attribute if it would // unexpectedly change the behavior of a wrapping overlay layer. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return "", syserror.ENODATA } @@ -553,9 +552,9 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin return s, err } -func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error { +func overlaySetXattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error { // Don't allow changes to overlay xattrs through a setxattr syscall. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return syserror.EPERM } @@ -578,7 +577,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st for name := range names { // Same as overlayGetXattr, we shouldn't forward along // overlay attributes. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { delete(names, name) } } @@ -587,7 +586,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error { // Don't allow changes to overlay xattrs through a removexattr syscall. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return syserror.EPERM } @@ -672,7 +671,7 @@ func overlayGetlink(ctx context.Context, o *overlayEntry) (*Dirent, error) { // ground and claim that jumping around the filesystem like this // is not supported. name, _ := dirent.FullName(nil) - dirent.DecRef() + dirent.DecRef(ctx) // Claim that the path is not accessible. err = syserror.EACCES diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go index 389c219d6..aa9851b26 100644 --- a/pkg/sentry/fs/inode_overlay_test.go +++ b/pkg/sentry/fs/inode_overlay_test.go @@ -316,7 +316,7 @@ func TestCacheFlush(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) ctx = &rootContext{ Context: ctx, @@ -345,7 +345,7 @@ func TestCacheFlush(t *testing.T) { } // Drop the file reference. - file.DecRef() + file.DecRef(ctx) // Dirent should have 2 refs left. if got, want := dirent.ReadRefs(), 2; int(got) != want { @@ -361,7 +361,7 @@ func TestCacheFlush(t *testing.T) { } // Drop our ref. - dirent.DecRef() + dirent.DecRef(ctx) // We should be back to zero refs. if got, want := dirent.ReadRefs(), 0; int(got) != want { @@ -398,7 +398,7 @@ func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags if err != nil { return nil, err } - defer file.DecRef() + defer file.DecRef(ctx) // Wrap the file's FileOperations in a dirFile. fops := &dirFile{ FileOperations: file.FileOperations, diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index e3a715c1f..c5c07d564 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -80,7 +80,7 @@ func NewInotify(ctx context.Context) *Inotify { // Release implements FileOperations.Release. Release removes all watches and // frees all resources for an inotify instance. -func (i *Inotify) Release() { +func (i *Inotify) Release(ctx context.Context) { // We need to hold i.mu to avoid a race with concurrent calls to // Inotify.targetDestroyed from Watches. There's no risk of Watches // accessing this Inotify after the destructor ends, because we remove all @@ -93,7 +93,7 @@ func (i *Inotify) Release() { // the owner's destructor. w.target.Watches.Remove(w.ID()) // Don't leak any references to the target, held by pins in the watch. - w.destroy() + w.destroy(ctx) } } @@ -321,7 +321,7 @@ func (i *Inotify) AddWatch(target *Dirent, mask uint32) int32 { // // RmWatch looks up an inotify watch for the given 'wd' and configures the // target dirent to stop sending events to this inotify instance. -func (i *Inotify) RmWatch(wd int32) error { +func (i *Inotify) RmWatch(ctx context.Context, wd int32) error { i.mu.Lock() // Find the watch we were asked to removed. @@ -346,7 +346,7 @@ func (i *Inotify) RmWatch(wd int32) error { i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0)) // Remove all pins. - watch.destroy() + watch.destroy(ctx) return nil } diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go index 900cba3ca..605423d22 100644 --- a/pkg/sentry/fs/inotify_watch.go +++ b/pkg/sentry/fs/inotify_watch.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -105,12 +106,12 @@ func (w *Watch) Pin(d *Dirent) { // Unpin drops any extra refs held on dirent due to a previous Pin // call. Calling Unpin multiple times for the same dirent, or on a dirent // without a corresponding Pin call is a no-op. -func (w *Watch) Unpin(d *Dirent) { +func (w *Watch) Unpin(ctx context.Context, d *Dirent) { w.mu.Lock() defer w.mu.Unlock() if w.pins[d] { delete(w.pins, d) - d.DecRef() + d.DecRef(ctx) } } @@ -125,11 +126,11 @@ func (w *Watch) TargetDestroyed() { // this watch. Destroy does not cause any new events to be generated. The caller // is responsible for ensuring there are no outstanding references to this // watch. -func (w *Watch) destroy() { +func (w *Watch) destroy(ctx context.Context) { w.mu.Lock() defer w.mu.Unlock() for d := range w.pins { - d.DecRef() + d.DecRef(ctx) } w.pins = nil } diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go index 37bae6810..ee69b10e8 100644 --- a/pkg/sentry/fs/mount.go +++ b/pkg/sentry/fs/mount.go @@ -51,7 +51,7 @@ type MountSourceOperations interface { DirentOperations // Destroy destroys the MountSource. - Destroy() + Destroy(ctx context.Context) // Below are MountSourceOperations that do not conform to Linux. @@ -165,16 +165,16 @@ func (msrc *MountSource) DecDirentRefs() { } } -func (msrc *MountSource) destroy() { +func (msrc *MountSource) destroy(ctx context.Context) { if c := msrc.DirentRefs(); c != 0 { panic(fmt.Sprintf("MountSource with non-zero direntRefs is being destroyed: %d", c)) } - msrc.MountSourceOperations.Destroy() + msrc.MountSourceOperations.Destroy(ctx) } // DecRef drops a reference on the MountSource. -func (msrc *MountSource) DecRef() { - msrc.DecRefWithDestructor(msrc.destroy) +func (msrc *MountSource) DecRef(ctx context.Context) { + msrc.DecRefWithDestructor(ctx, msrc.destroy) } // FlushDirentRefs drops all references held by the MountSource on Dirents. @@ -264,7 +264,7 @@ func (*SimpleMountSourceOperations) ResetInodeMappings() {} func (*SimpleMountSourceOperations) SaveInodeMapping(*Inode, string) {} // Destroy implements MountSourceOperations.Destroy. -func (*SimpleMountSourceOperations) Destroy() {} +func (*SimpleMountSourceOperations) Destroy(context.Context) {} // Info defines attributes of a filesystem. type Info struct { diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go index 78e35b1e6..7badc75d6 100644 --- a/pkg/sentry/fs/mount_overlay.go +++ b/pkg/sentry/fs/mount_overlay.go @@ -115,9 +115,9 @@ func (o *overlayMountSourceOperations) SaveInodeMapping(inode *Inode, path strin } // Destroy drops references on the upper and lower MountSource. -func (o *overlayMountSourceOperations) Destroy() { - o.upper.DecRef() - o.lower.DecRef() +func (o *overlayMountSourceOperations) Destroy(ctx context.Context) { + o.upper.DecRef(ctx) + o.lower.DecRef(ctx) } // type overlayFilesystem is the filesystem for overlay mounts. diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go index a3d10770b..6c296f5d0 100644 --- a/pkg/sentry/fs/mount_test.go +++ b/pkg/sentry/fs/mount_test.go @@ -18,6 +18,7 @@ import ( "fmt" "testing" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/contexttest" ) @@ -32,13 +33,13 @@ func cacheReallyContains(cache *DirentCache, d *Dirent) bool { return false } -func mountPathsAre(root *Dirent, got []*Mount, want ...string) error { +func mountPathsAre(ctx context.Context, root *Dirent, got []*Mount, want ...string) error { gotPaths := make(map[string]struct{}, len(got)) gotStr := make([]string, len(got)) for i, g := range got { if groot := g.Root(); groot != nil { name, _ := groot.FullName(root) - groot.DecRef() + groot.DecRef(ctx) gotStr[i] = name gotPaths[name] = struct{}{} } @@ -69,7 +70,7 @@ func TestMountSourceOnlyCachedOnce(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } rootDirent := mm.Root() - defer rootDirent.DecRef() + defer rootDirent.DecRef(ctx) // Get a child of the root which we will mount over. Note that the // MockInodeOperations causes Walk to always succeed. @@ -125,7 +126,7 @@ func TestAllMountsUnder(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } rootDirent := mm.Root() - defer rootDirent.DecRef() + defer rootDirent.DecRef(ctx) // Add mounts at the following paths: paths := []string{ @@ -150,14 +151,14 @@ func TestAllMountsUnder(t *testing.T) { if err := mm.Mount(ctx, d, submountInode); err != nil { t.Fatalf("could not mount at %q: %v", p, err) } - d.DecRef() + d.DecRef(ctx) } // mm root should contain all submounts (and does not include the root mount). rootMnt := mm.FindMount(rootDirent) submounts := mm.AllMountsUnder(rootMnt) allPaths := append(paths, "/") - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil { t.Error(err) } @@ -181,9 +182,9 @@ func TestAllMountsUnder(t *testing.T) { if err != nil { t.Fatalf("could not find path %q in mount manager: %v", "/foo", err) } - defer d.DecRef() + defer d.DecRef(ctx) submounts = mm.AllMountsUnder(mm.FindMount(d)) - if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { t.Error(err) } @@ -193,9 +194,9 @@ func TestAllMountsUnder(t *testing.T) { if err != nil { t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err) } - defer waldo.DecRef() + defer waldo.DecRef(ctx) submounts = mm.AllMountsUnder(mm.FindMount(waldo)) - if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, "/waldo"); err != nil { t.Error(err) } } @@ -212,7 +213,7 @@ func TestUnmount(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } rootDirent := mm.Root() - defer rootDirent.DecRef() + defer rootDirent.DecRef(ctx) // Add mounts at the following paths: paths := []string{ @@ -240,7 +241,7 @@ func TestUnmount(t *testing.T) { if err := mm.Mount(ctx, d, submountInode); err != nil { t.Fatalf("could not mount at %q: %v", p, err) } - d.DecRef() + d.DecRef(ctx) } allPaths := make([]string, len(paths)+1) @@ -259,13 +260,13 @@ func TestUnmount(t *testing.T) { if err := mm.Unmount(ctx, d, false); err != nil { t.Fatalf("could not unmount at %q: %v", p, err) } - d.DecRef() + d.DecRef(ctx) // Remove the path that has been unmounted and the check that the remaining // mounts are still there. allPaths = allPaths[:len(allPaths)-1] submounts := mm.AllMountsUnder(rootMnt) - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil { t.Error(err) } } diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 3f2bd0e87..d741c4339 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -234,7 +234,7 @@ func (mns *MountNamespace) flushMountSourceRefsLocked() { // After destroy is called, the MountNamespace may continue to be referenced (for // example via /proc/mounts), but should free all resources and shouldn't have // Find* methods called. -func (mns *MountNamespace) destroy() { +func (mns *MountNamespace) destroy(ctx context.Context) { mns.mu.Lock() defer mns.mu.Unlock() @@ -247,13 +247,13 @@ func (mns *MountNamespace) destroy() { for _, mp := range mns.mounts { // Drop the mount reference on all mounted dirents. for ; mp != nil; mp = mp.previous { - mp.root.DecRef() + mp.root.DecRef(ctx) } } mns.mounts = nil // Drop reference on the root. - mns.root.DecRef() + mns.root.DecRef(ctx) // Ensure that root cannot be accessed via this MountNamespace any // more. @@ -265,8 +265,8 @@ func (mns *MountNamespace) destroy() { } // DecRef implements RefCounter.DecRef with destructor mns.destroy. -func (mns *MountNamespace) DecRef() { - mns.DecRefWithDestructor(mns.destroy) +func (mns *MountNamespace) DecRef(ctx context.Context) { + mns.DecRefWithDestructor(ctx, mns.destroy) } // withMountLocked prevents further walks to `node`, because `node` is about to @@ -312,7 +312,7 @@ func (mns *MountNamespace) Mount(ctx context.Context, mountPoint *Dirent, inode if err != nil { return err } - defer replacement.DecRef() + defer replacement.DecRef(ctx) // Set the mount's root dirent and id. parentMnt := mns.findMountLocked(mountPoint) @@ -394,7 +394,7 @@ func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly panic(fmt.Sprintf("Last mount in the chain must be a undo mount: %+v", prev)) } // Drop mount reference taken at the end of MountNamespace.Mount. - prev.root.DecRef() + prev.root.DecRef(ctx) } else { mns.mounts[prev.root] = prev } @@ -496,11 +496,11 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path // non-directory root is hopeless. if current != root { if !IsDir(current.Inode.StableAttr) { - current.DecRef() // Drop reference from above. + current.DecRef(ctx) // Drop reference from above. return nil, syserror.ENOTDIR } if err := current.Inode.CheckPermission(ctx, PermMask{Execute: true}); err != nil { - current.DecRef() // Drop reference from above. + current.DecRef(ctx) // Drop reference from above. return nil, err } } @@ -511,12 +511,12 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path // Allow failed walks to cache the dirent, because no // children will acquire a reference at the end. current.maybeExtendReference() - current.DecRef() + current.DecRef(ctx) return nil, err } // Drop old reference. - current.DecRef() + current.DecRef(ctx) if remainder != "" { // Ensure it's resolved, unless it's the last level. @@ -570,11 +570,11 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema case nil: // Make sure we didn't exhaust the traversal budget. if *remainingTraversals == 0 { - target.DecRef() + target.DecRef(ctx) return nil, syscall.ELOOP } - node.DecRef() // Drop the original reference. + node.DecRef(ctx) // Drop the original reference. return target, nil case syscall.ENOLINK: @@ -582,7 +582,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema return node, nil case ErrResolveViaReadlink: - defer node.DecRef() // See above. + defer node.DecRef(ctx) // See above. // First, check if we should traverse. if *remainingTraversals == 0 { @@ -608,7 +608,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema return d, err default: - node.DecRef() // Drop for err; see above. + node.DecRef(ctx) // Drop for err; see above. // Propagate the error. return nil, err diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go index a69b41468..975d6cbc9 100644 --- a/pkg/sentry/fs/mounts_test.go +++ b/pkg/sentry/fs/mounts_test.go @@ -51,7 +51,7 @@ func TestFindLink(t *testing.T) { } root := mm.Root() - defer root.DecRef() + defer root.DecRef(ctx) foo, err := root.Walk(ctx, root, "foo") if err != nil { t.Fatalf("Error walking to foo: %v", err) diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index a8ae7d81d..01a1235b8 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -86,13 +86,12 @@ func isXattrOverlay(name string) bool { // NewOverlayRoot produces the root of an overlay. // // Preconditions: -// -// - upper and lower must be non-nil. -// - upper must not be an overlay. -// - lower should not expose character devices, pipes, or sockets, because +// * upper and lower must be non-nil. +// * upper must not be an overlay. +// * lower should not expose character devices, pipes, or sockets, because // copying up these types of files is not supported. -// - lower must not require that file objects be revalidated. -// - lower must not have dynamic file/directory content. +// * lower must not require that file objects be revalidated. +// * lower must not have dynamic file/directory content. func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags MountSourceFlags) (*Inode, error) { if !IsDir(upper.StableAttr) { return nil, fmt.Errorf("upper Inode is a %v, not a directory", upper.StableAttr.Type) @@ -107,7 +106,7 @@ func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags Mount msrc := newOverlayMountSource(ctx, upper.MountSource, lower.MountSource, flags) overlay, err := newOverlayEntry(ctx, upper, lower, true) if err != nil { - msrc.DecRef() + msrc.DecRef(ctx) return nil, err } @@ -117,12 +116,11 @@ func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags Mount // NewOverlayRootFile produces the root of an overlay that points to a file. // // Preconditions: -// -// - lower must be non-nil. -// - lower should not expose character devices, pipes, or sockets, because +// * lower must be non-nil. +// * lower should not expose character devices, pipes, or sockets, because // copying up these types of files is not supported. Neither it can be a dir. -// - lower must not require that file objects be revalidated. -// - lower must not have dynamic file/directory content. +// * lower must not require that file objects be revalidated. +// * lower must not have dynamic file/directory content. func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode, flags MountSourceFlags) (*Inode, error) { if !IsRegular(lower.StableAttr) { return nil, fmt.Errorf("lower Inode is not a regular file") @@ -130,7 +128,7 @@ func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode, msrc := newOverlayMountSource(ctx, upperMS, lower.MountSource, flags) overlay, err := newOverlayEntry(ctx, nil, lower, true) if err != nil { - msrc.DecRef() + msrc.DecRef(ctx) return nil, err } return newOverlayInode(ctx, overlay, msrc), nil @@ -230,16 +228,16 @@ func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExist }, nil } -func (o *overlayEntry) release() { +func (o *overlayEntry) release(ctx context.Context) { // We drop a reference on upper and lower file system Inodes // rather than releasing them, because in-memory filesystems // may hold an extra reference to these Inodes so that they // stay in memory. if o.upper != nil { - o.upper.DecRef() + o.upper.DecRef(ctx) } if o.lower != nil { - o.lower.DecRef() + o.lower.DecRef(ctx) } } diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 77c2c5c0e..b8b2281a8 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go index 35972e23c..45523adf8 100644 --- a/pkg/sentry/fs/proc/fds.go +++ b/pkg/sentry/fs/proc/fds.go @@ -56,11 +56,11 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF // readDescriptors reads fds in the task starting at offset, and calls the // toDentAttr callback for each to get a DentAttr, which it then emits. This is // a helper for implementing fs.InodeOperations.Readdir. -func readDescriptors(t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) { +func readDescriptors(ctx context.Context, t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) { var fds []int32 t.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.GetFDs() + fds = fdTable.GetFDs(ctx) } }) @@ -116,7 +116,7 @@ func (f *fd) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error func (f *fd) Readlink(ctx context.Context, _ *fs.Inode) (string, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } n, _ := f.file.Dirent.FullName(root) return n, nil @@ -135,13 +135,7 @@ func (f *fd) Truncate(context.Context, *fs.Inode, int64) error { func (f *fd) Release(ctx context.Context) { f.Symlink.Release(ctx) - f.file.DecRef() -} - -// Close releases the reference on the file. -func (f *fd) Close() error { - f.file.DecRef() - return nil + f.file.DecRef(ctx) } // fdDir is an InodeOperations for /proc/TID/fd. @@ -227,7 +221,7 @@ func (f *fdDirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySer if f.isInfoFile { typ = fs.Symlink } - return readDescriptors(f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr { + return readDescriptors(ctx, f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr { return fs.GenericDentAttr(typ, device.ProcDevice) }) } @@ -261,7 +255,7 @@ func (fdid *fdInfoDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs // locks, and other data. For now we only have flags. // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt flags := file.Flags().ToLinux() | fdFlags.ToLinuxFileFlags() - file.DecRef() + file.DecRef(ctx) contents := []byte(fmt.Sprintf("flags:\t0%o\n", flags)) return newStaticProcInode(ctx, dir.MountSource, contents) }) diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index 1fc9c703c..6a63c47b3 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -47,7 +47,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { // The task has been destroyed. Nothing to show here. return } - defer rootDir.DecRef() + defer rootDir.DecRef(t) mnt := t.MountNamespace().FindMount(rootDir) if mnt == nil { @@ -64,7 +64,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { continue // No longer valid. } mountPath, desc := mroot.FullName(rootDir) - mroot.DecRef() + mroot.DecRef(t) if !desc { // MountSources that are not descendants of the chroot jail are ignored. continue @@ -97,7 +97,7 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se if mroot == nil { return // No longer valid. } - defer mroot.DecRef() + defer mroot.DecRef(ctx) // Format: // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue @@ -216,7 +216,7 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan if root == nil { return // No longer valid. } - defer root.DecRef() + defer root.DecRef(ctx) flags := root.Inode.MountSource.Flags opts := "rw" diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index bd18177d4..83a43aa26 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -419,7 +419,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } sfile := s.(*fs.File) if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { - s.DecRef() + s.DecRef(ctx) // Not a unix socket. continue } @@ -479,7 +479,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } fmt.Fprintf(&buf, "\n") - s.DecRef() + s.DecRef(ctx) } data := []seqfile.SeqData{ @@ -574,7 +574,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { - s.DecRef() + s.DecRef(ctx) // Not tcp4 sockets. continue } @@ -664,7 +664,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne fmt.Fprintf(&buf, "\n") - s.DecRef() + s.DecRef(ctx) } data := []seqfile.SeqData{ @@ -752,7 +752,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { - s.DecRef() + s.DecRef(ctx) // Not udp4 socket. continue } @@ -822,7 +822,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se fmt.Fprintf(&buf, "\n") - s.DecRef() + s.DecRef(ctx) } data := []seqfile.SeqData{ diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index c659224a7..77e0e1d26 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -213,7 +213,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // Add dot and dotdot. root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dot, dotdot := file.Dirent.GetDotAttrs(root) names = append(names, ".", "..") diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 702fdd392..e555672ad 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,7 +55,7 @@ type tcpMemInode struct { // size stores the tcp buffer size during save, and sets the buffer // size in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. + // a netstack instance is created on restore. size inet.TCPBufferSize // mu protects against concurrent reads/writes to files based on this @@ -258,6 +259,9 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque if src.NumBytes() == 0 { return 0, nil } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -272,6 +276,96 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled) } +// +stateify savable +type tcpRecovery struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + recovery inet.TCPLossRecovery +} + +func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ts := &tcpRecovery{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ts, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (*tcpRecovery) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (r *tcpRecovery) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &tcpRecoveryFile{ + tcpRecovery: r, + stack: r.stack, + }), nil +} + +// +stateify savable +type tcpRecoveryFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + tcpRecovery *tcpRecovery + + stack inet.Stack `state:"wait"` +} + +// Read implements fs.FileOperations.Read. +func (f *tcpRecoveryFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + recovery, err := f.stack.TCPRecovery() + if err != nil { + return 0, err + } + f.tcpRecovery.recovery = recovery + s := fmt.Sprintf("%d\n", f.tcpRecovery.recovery) + n, err := dst.CopyOut(ctx, []byte(s)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + f.tcpRecovery.recovery = inet.TCPLossRecovery(v) + if err := f.tcpRecovery.stack.SetTCPRecovery(f.tcpRecovery.recovery); err != nil { + return 0, err + } + return n, nil +} + func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -293,11 +387,125 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } +// ipForwarding implements fs.InodeOperations. +// +// ipForwarding is used to enable/disable packet forwarding of netstack. +// +// +stateify savable +type ipForwarding struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + + // enabled stores the IPv4 forwarding state on save. + // We must save/restore this here, since a netstack instance + // is created on restore. + enabled *bool +} + +func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &ipForwarding{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*ipForwarding) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// +stateify savable +type ipForwardingFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + ipf *ipForwarding + + stack inet.Stack `state:"wait"` +} + +// GetFile implements fs.InodeOperations.GetFile. +func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{ + stack: ipf.stack, + ipf: ipf, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + if f.ipf.enabled == nil { + enabled := f.stack.Forwarding(ipv4.ProtocolNumber) + f.ipf.enabled = &enabled + } + + val := "0\n" + if *f.ipf.enabled { + // Technically, this is not quite compatible with Linux. Linux + // stores these as an integer, so if you write "2" into + // ip_forward, you should get 2 back. + val = "1\n" + } + n, err := dst.CopyOut(ctx, []byte(val)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return n, err + } + if f.ipf.enabled == nil { + f.ipf.enabled = new(bool) + } + *f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. "tcp_sack": newTCPSackInode(ctx, msrc, s), + // Add ip_forward. + "ip_forward": newIPForwardingInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the @@ -351,6 +559,11 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem) } + // Add tcp_recovery. + if _, err := s.TCPRecovery(); err == nil { + contents["tcp_recovery"] = newTCPRecoveryInode(ctx, msrc, s) + } + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 6eba709c6..4cb4741af 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -14,7 +14,11 @@ package proc -import "fmt" +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" +) // beforeSave is invoked by stateify. func (t *tcpMemInode) beforeSave() { @@ -40,3 +44,12 @@ func (s *tcpSack) afterLoad() { } } } + +// afterLoad is invoked by stateify. +func (ipf *ipForwarding) afterLoad() { + if ipf.enabled != nil { + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) + } + } +} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 355e83d47..6ef5738e7 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -123,3 +123,76 @@ func TestConfigureRecvBufferSize(t *testing.T) { } } } + +// TestIPForwarding tests the implementation of +// /proc/sys/net/ipv4/ip_forwarding +func TestIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + ipf := &ipForwarding{stack: s} + file := &ipForwardingFile{ + stack: s, + ipf: ipf, + } + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) + } + + }) + } +} diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 4bbe90198..22d658acf 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -84,6 +84,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "auxv": newAuxvec(t, msrc), "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), "comm": newComm(t, msrc), + "cwd": newCwd(t, msrc), "environ": newExecArgInode(t, msrc, environExecArg), "exe": newExe(t, msrc), "fd": newFdDir(t, msrc), @@ -185,7 +186,7 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry // Serialize "." and "..". root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dot, dotdot := file.Dirent.GetDotAttrs(root) if err := dirCtx.DirEmit(".", dot); err != nil { @@ -295,11 +296,54 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { if err != nil { return "", err } - defer exec.DecRef() + defer exec.DecRef(ctx) return exec.PathnameWithDeleted(ctx), nil } +// cwd is an fs.InodeOperations symlink for the /proc/PID/cwd file. +// +// +stateify savable +type cwd struct { + ramfs.Symlink + + t *kernel.Task +} + +func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + cwdSymlink := &cwd{ + Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + t: t, + } + return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t) +} + +// Readlink implements fs.InodeOperations. +func (e *cwd) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { + if !kernel.ContextCanTrace(ctx, e.t, false) { + return "", syserror.EACCES + } + if err := checkTaskState(e.t); err != nil { + return "", err + } + cwd := e.t.FSContext().WorkingDirectory() + if cwd == nil { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer cwd.DecRef(ctx) + + root := fs.RootFromContext(ctx) + if root == nil { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + name, _ := cwd.FullName(root) + return name, nil +} + // namespaceSymlink represents a symlink in the namespacefs, such as the files // in /proc/<pid>/ns. // @@ -604,7 +648,7 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( var vss, rss, data uint64 s.t.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index bfa304552..f4fcddecb 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -219,7 +219,7 @@ func (d *Dir) Remove(ctx context.Context, _ *fs.Inode, name string) error { } // Remove our reference on the inode. - inode.DecRef() + inode.DecRef(ctx) return nil } @@ -250,7 +250,7 @@ func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) err } // Remove our reference on the inode. - inode.DecRef() + inode.DecRef(ctx) return nil } @@ -326,7 +326,7 @@ func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.F // Create the Dirent and corresponding file. created := fs.NewDirent(ctx, inode, name) - defer created.DecRef() + defer created.DecRef(ctx) return created.Inode.GetFile(ctx, created, flags) } @@ -412,11 +412,11 @@ func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, ol } // Release implements fs.InodeOperation.Release. -func (d *Dir) Release(_ context.Context) { +func (d *Dir) Release(ctx context.Context) { // Drop references on all children. d.mu.Lock() for _, i := range d.children { - i.DecRef() + i.DecRef(ctx) } d.mu.Unlock() } @@ -456,7 +456,7 @@ func (dfo *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirC func (dfo *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, @@ -473,13 +473,13 @@ func hasChildren(ctx context.Context, inode *fs.Inode) (bool, error) { // dropped when that dirent is destroyed. inode.IncRef() d := fs.NewTransientDirent(inode) - defer d.DecRef() + defer d.DecRef(ctx) file, err := inode.GetFile(ctx, d, fs.FileFlags{Read: true}) if err != nil { return false, err } - defer file.DecRef() + defer file.DecRef(ctx) ser := &fs.CollectEntriesSerializer{} if err := file.Readdir(ctx, ser); err != nil { @@ -530,7 +530,7 @@ func Rename(ctx context.Context, oldParent fs.InodeOperations, oldName string, n if err != nil { return err } - inode.DecRef() + inode.DecRef(ctx) } // Be careful, we may have already grabbed this mutex above. diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go index a6ed8b2c5..3e0d1e07e 100644 --- a/pkg/sentry/fs/ramfs/tree_test.go +++ b/pkg/sentry/fs/ramfs/tree_test.go @@ -67,7 +67,7 @@ func TestMakeDirectoryTree(t *testing.T) { continue } root := mm.Root() - defer mm.DecRef() + defer mm.DecRef(ctx) for _, p := range test.subdirs { maxTraversals := uint(0) diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 88c344089..f362ca9b6 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -55,7 +55,7 @@ type TimerOperations struct { func NewFile(ctx context.Context, c ktime.Clock) *fs.File { dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[timerfd]") // Release the initial dirent reference after NewFile takes a reference. - defer dirent.DecRef() + defer dirent.DecRef(ctx) tops := &TimerOperations{} tops.timer = ktime.NewTimer(c, tops) // Timerfds reject writes, but the Write flag must be set in order to @@ -65,7 +65,7 @@ func NewFile(ctx context.Context, c ktime.Clock) *fs.File { } // Release implements fs.FileOperations.Release. -func (t *TimerOperations) Release() { +func (t *TimerOperations) Release(context.Context) { t.timer.Destroy() } diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go index aaba35502..d4d613ea9 100644 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ b/pkg/sentry/fs/tmpfs/file_test.go @@ -46,7 +46,7 @@ func newFile(ctx context.Context) *fs.File { func TestGrow(t *testing.T) { ctx := contexttest.Context(t) f := newFile(ctx) - defer f.DecRef() + defer f.DecRef(ctx) abuf := bytes.Repeat([]byte{'a'}, 68) n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0) diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index b095312fe..998b697ca 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -16,6 +16,8 @@ package tmpfs import ( + "math" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -32,9 +34,15 @@ import ( var fsInfo = fs.Info{ Type: linux.TMPFS_MAGIC, + // tmpfs currently does not support configurable size limits. In Linux, + // such a tmpfs mount will return f_blocks == f_bfree == f_bavail == 0 from + // statfs(2). However, many applications treat this as having a size limit + // of 0. To work around this, claim to have a very large but non-zero size, + // chosen to ensure that BlockSize * Blocks does not overflow int64 (which + // applications may also handle incorrectly). // TODO(b/29637826): allow configuring a tmpfs size and enforce it. - TotalBlocks: 0, - FreeBlocks: 0, + TotalBlocks: math.MaxInt64 / usermem.PageSize, + FreeBlocks: math.MaxInt64 / usermem.PageSize, } // rename implements fs.InodeOperations.Rename for tmpfs nodes. diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5cb0e0417..e6d0eb359 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -10,13 +10,14 @@ go_library( "line_discipline.go", "master.go", "queue.go", - "slave.go", + "replica.go", "terminal.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 108654827..c2da80bc2 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -37,14 +37,14 @@ import ( // This indirectly manages all terminals within the mount. // // New Terminals are created by masterInodeOperations.GetFile, which registers -// the slave Inode in the this directory for discovery via Lookup/Readdir. The -// slave inode is unregistered when the master file is Released, as the slave +// the replica Inode in the this directory for discovery via Lookup/Readdir. The +// replica inode is unregistered when the master file is Released, as the replica // is no longer discoverable at that point. // // References on the underlying Terminal are held by masterFileOperations and -// slaveInodeOperations. +// replicaInodeOperations. // -// masterInodeOperations and slaveInodeOperations hold a pointer to +// masterInodeOperations and replicaInodeOperations hold a pointer to // dirInodeOperations, which is reference counted by the refcount their // corresponding Dirents hold on their parent (this directory). // @@ -76,16 +76,16 @@ type dirInodeOperations struct { // master is the master PTY inode. master *fs.Inode - // slaves contains the slave inodes reachable from the directory. + // replicas contains the replica inodes reachable from the directory. // - // A new slave is added by allocateTerminal and is removed by + // A new replica is added by allocateTerminal and is removed by // masterFileOperations.Release. // - // A reference is held on every slave in the map. - slaves map[uint32]*fs.Inode + // A reference is held on every replica in the map. + replicas map[uint32]*fs.Inode // dentryMap is a SortedDentryMap used to implement Readdir containing - // the master and all entries in slaves. + // the master and all entries in replicas. dentryMap *fs.SortedDentryMap // next is the next pty index to use. @@ -101,7 +101,7 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { d := &dirInodeOperations{ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0555), linux.DEVPTS_SUPER_MAGIC), msrc: m, - slaves: make(map[uint32]*fs.Inode), + replicas: make(map[uint32]*fs.Inode), dentryMap: fs.NewSortedDentryMap(nil), } // Linux devpts uses a default mode of 0000 for ptmx which can be @@ -132,8 +132,8 @@ func (d *dirInodeOperations) Release(ctx context.Context) { d.mu.Lock() defer d.mu.Unlock() - d.master.DecRef() - if len(d.slaves) != 0 { + d.master.DecRef(ctx) + if len(d.replicas) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) } } @@ -149,14 +149,14 @@ func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name str return fs.NewDirent(ctx, d.master, name), nil } - // Slave number? + // Replica number? n, err := strconv.ParseUint(name, 10, 32) if err != nil { // Not found. return nil, syserror.ENOENT } - s, ok := d.slaves[uint32(n)] + s, ok := d.replicas[uint32(n)] if !ok { return nil, syserror.ENOENT } @@ -236,7 +236,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e return nil, syserror.ENOMEM } - if _, ok := d.slaves[n]; ok { + if _, ok := d.replicas[n]; ok { panic(fmt.Sprintf("pty index collision; index %d already exists", n)) } @@ -244,41 +244,41 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e d.next++ // The reference returned by newTerminal is returned to the caller. - // Take another for the slave inode. + // Take another for the replica inode. t.IncRef() // Create a pts node. The owner is based on the context that opens // ptmx. creds := auth.CredentialsFromContext(ctx) uid, gid := creds.EffectiveKUID, creds.EffectiveKGID - slave := newSlaveInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666)) + replica := newReplicaInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666)) - d.slaves[n] = slave + d.replicas[n] = replica d.dentryMap.Add(strconv.FormatUint(uint64(n), 10), fs.DentAttr{ - Type: slave.StableAttr.Type, - InodeID: slave.StableAttr.InodeID, + Type: replica.StableAttr.Type, + InodeID: replica.StableAttr.InodeID, }) return t, nil } // masterClose is called when the master end of t is closed. -func (d *dirInodeOperations) masterClose(t *Terminal) { +func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) { d.mu.Lock() defer d.mu.Unlock() - // The slave end disappears from the directory when the master end is - // closed, even if the slave end is open elsewhere. + // The replica end disappears from the directory when the master end is + // closed, even if the replica end is open elsewhere. // // N.B. since we're using a backdoor method to remove a directory entry // we won't properly fire inotify events like Linux would. - s, ok := d.slaves[t.n] + s, ok := d.replicas[t.n] if !ok { panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d)) } - s.DecRef() - delete(d.slaves, t.n) + s.DecRef(ctx) + delete(d.replicas, t.n) d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10)) } @@ -322,7 +322,7 @@ func (df *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCt func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go index 8fe05ebe5..13f4901db 100644 --- a/pkg/sentry/fs/tty/fs.go +++ b/pkg/sentry/fs/tty/fs.go @@ -79,8 +79,8 @@ type superOperations struct{} // // It always returns true, forcing a Lookup for all entries. // -// Slave entries are dropped from dir when their master is closed, so an -// existing slave Dirent in the tree is not sufficient to guarantee that it +// Replica entries are dropped from dir when their master is closed, so an +// existing replica Dirent in the tree is not sufficient to guarantee that it // still exists on the filesystem. func (superOperations) Revalidate(context.Context, string, *fs.Inode, *fs.Inode) bool { return true @@ -108,4 +108,4 @@ func (superOperations) ResetInodeMappings() {} func (superOperations) SaveInodeMapping(*fs.Inode, string) {} // Destroy implements MountSourceOperations.Destroy. -func (superOperations) Destroy() {} +func (superOperations) Destroy(context.Context) {} diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 2e9dd2d55..b34f4a0eb 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -43,7 +44,7 @@ const ( ) // lineDiscipline dictates how input and output are handled between the -// pseudoterminal (pty) master and slave. It can be configured to alter I/O, +// pseudoterminal (pty) master and replica. It can be configured to alter I/O, // modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man // pages are good resources for how to affect the line discipline: // @@ -54,8 +55,8 @@ const ( // // lineDiscipline has a simple structure but supports a multitude of options // (see the above man pages). It consists of two queues of bytes: one from the -// terminal master to slave (the input queue) and one from slave to master (the -// output queue). When bytes are written to one end of the pty, the line +// terminal master to replica (the input queue) and one from replica to master +// (the output queue). When bytes are written to one end of the pty, the line // discipline reads the bytes, modifies them or takes special action if // required, and enqueues them to be read by the other end of the pty: // @@ -64,7 +65,7 @@ const ( // | (inputQueueWrite) +-------------+ (inputQueueRead) | // | | // | v -// masterFD slaveFD +// masterFD replicaFD // ^ | // | | // | output to terminal +--------------+ output from process | @@ -103,8 +104,8 @@ type lineDiscipline struct { // masterWaiter is used to wait on the master end of the TTY. masterWaiter waiter.Queue `state:"zerovalue"` - // slaveWaiter is used to wait on the slave end of the TTY. - slaveWaiter waiter.Queue `state:"zerovalue"` + // replicaWaiter is used to wait on the replica end of the TTY. + replicaWaiter waiter.Queue `state:"zerovalue"` } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { @@ -115,27 +116,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { } // getTermios gets the linux.Termios for the tty. -func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.RLock() defer l.termiosMu.RUnlock() // We must copy a Termios struct, not KernelTermios. t := l.termios.ToTermios() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyOut(task, args[2].Pointer()) return 0, err } // setTermios sets a linux.Termios for the tty. -func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.Lock() defer l.termiosMu.Unlock() oldCanonEnabled := l.termios.LEnabled(linux.ICANON) // We must copy a Termios struct, not KernelTermios. var t linux.Termios - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyIn(task, args[2].Pointer()) l.termios.FromTermios(t) // If canonical mode is turned off, move bytes from inQueue's wait @@ -146,27 +143,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc l.inQueue.pushWaitBufLocked(l) l.inQueue.readable = true l.inQueue.mu.Unlock() - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return 0, err } -func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyOut(t, args[2].Pointer()) return err } -func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyIn(t, args[2].Pointer()) return err } @@ -176,14 +169,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask { return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios) } -func (l *lineDiscipline) slaveReadiness() waiter.EventMask { +func (l *lineDiscipline) replicaReadiness() waiter.EventMask { l.termiosMu.RLock() defer l.termiosMu.RUnlock() return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios) } -func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.inQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { + return l.inQueue.readableSize(t, args) } func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -196,7 +189,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque if n > 0 { l.masterWaiter.Notify(waiter.EventOut) if pushed { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return n, nil } @@ -211,14 +204,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) return n, nil } return 0, syserror.ErrWouldBlock } -func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.outQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { + return l.outQueue.readableSize(t, args) } func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -229,7 +222,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventOut) + l.replicaWaiter.Notify(waiter.EventOut) if pushed { l.masterWaiter.Notify(waiter.EventIn) } diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index fe07fa929..b91184b1b 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -17,9 +17,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -75,7 +77,7 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn } // Release implements fs.InodeOperations.Release. -func (mi *masterInodeOperations) Release(ctx context.Context) { +func (mi *masterInodeOperations) Release(context.Context) { } // Truncate implements fs.InodeOperations.Truncate. @@ -120,9 +122,9 @@ type masterFileOperations struct { var _ fs.FileOperations = (*masterFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (mf *masterFileOperations) Release() { - mf.d.masterClose(mf.t) - mf.t.DecRef() +func (mf *masterFileOperations) Release(ctx context.Context) { + mf.d.masterClose(ctx, mf.t) + mf.t.DecRef(ctx) } // EventRegister implements waiter.Waitable.EventRegister. @@ -152,46 +154,51 @@ func (mf *masterFileOperations) Write(ctx context.Context, _ *fs.File, src userm // Ioctl implements fs.FileOperations.Ioctl. func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the output queue read buffer. - return 0, mf.t.ld.outputQueueReadSize(ctx, io, args) + return 0, mf.t.ld.outputQueueReadSize(t, args) case linux.TCGETS: // N.B. TCGETS on the master actually returns the configuration - // of the slave end. - return mf.t.ld.getTermios(ctx, io, args) + // of the replica end. + return mf.t.ld.getTermios(t, args) case linux.TCSETS: // N.B. TCSETS on the master actually affects the configuration - // of the slave end. - return mf.t.ld.setTermios(ctx, io, args) + // of the replica end. + return mf.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return mf.t.ld.setTermios(ctx, io, args) + return mf.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mf.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(mf.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCSPTLCK: // TODO(b/29356795): Implement pty locking. For now just pretend we do. return 0, nil case linux.TIOCGWINSZ: - return 0, mf.t.ld.windowSize(ctx, io, args) + return 0, mf.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, mf.t.ld.setWindowSize(ctx, io, args) + return 0, mf.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mf.t.setControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mf.t.releaseControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + return mf.t.foregroundProcessGroup(ctx, args, true /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) + return mf.t.setForegroundProcessGroup(ctx, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index ceabb9b1e..79975d812 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -17,8 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -32,7 +34,7 @@ import ( const waitBufMaxBytes = 131072 // queue represents one of the input or output queues between a pty master and -// slave. Bytes written to a queue are added to the read buffer until it is +// replica. Bytes written to a queue are added to the read buffer until it is // full, at which point they are written to the wait buffer. Bytes are // processed (i.e. undergo termios transformations) as they are added to the // read buffer. The read buffer is readable when its length is nonzero and @@ -85,17 +87,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask { } // readableSize writes the number of readable bytes to userspace. -func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (q *queue) readableSize(t *kernel.Task, args arch.SyscallArguments) error { q.mu.Lock() defer q.mu.Unlock() - var size int32 + size := primitive.Int32(0) if q.readable { - size = int32(len(q.readBuf)) + size = primitive.Int32(len(q.readBuf)) } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := size.CopyOut(t, args[2].Pointer()) return err } @@ -104,8 +104,7 @@ func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.Sysca // as whether the read caused more readable data to become available (whether // data was pushed from the wait buffer to the read buffer). // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) { q.mu.Lock() defer q.mu.Unlock() @@ -145,8 +144,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl // write writes to q from userspace. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) { q.mu.Lock() defer q.mu.Unlock() @@ -188,8 +186,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip // writeBytes writes to q from b. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) writeBytes(b []byte, l *lineDiscipline) { q.mu.Lock() defer q.mu.Unlock() diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/replica.go index 9871f6fc6..385d230fb 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/replica.go @@ -17,9 +17,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -27,11 +29,11 @@ import ( // LINT.IfChange -// slaveInodeOperations are the fs.InodeOperations for the slave end of the +// replicaInodeOperations are the fs.InodeOperations for the replica end of the // Terminal (pts file). // // +stateify savable -type slaveInodeOperations struct { +type replicaInodeOperations struct { fsutil.SimpleFileInode // d is the containing dir. @@ -41,13 +43,13 @@ type slaveInodeOperations struct { t *Terminal } -var _ fs.InodeOperations = (*slaveInodeOperations)(nil) +var _ fs.InodeOperations = (*replicaInodeOperations)(nil) -// newSlaveInode creates an fs.Inode for the slave end of a terminal. +// newReplicaInode creates an fs.Inode for the replica end of a terminal. // -// newSlaveInode takes ownership of t. -func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode { - iops := &slaveInodeOperations{ +// newReplicaInode takes ownership of t. +func newReplicaInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode { + iops := &replicaInodeOperations{ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC), d: d, t: t, @@ -64,18 +66,18 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne Type: fs.CharacterDevice, // See fs/devpts/inode.c:devpts_fill_super. BlockSize: 1024, - DeviceFileMajor: linux.UNIX98_PTY_SLAVE_MAJOR, + DeviceFileMajor: linux.UNIX98_PTY_REPLICA_MAJOR, DeviceFileMinor: t.n, }) } // Release implements fs.InodeOperations.Release. -func (si *slaveInodeOperations) Release(ctx context.Context) { - si.t.DecRef() +func (si *replicaInodeOperations) Release(ctx context.Context) { + si.t.DecRef(ctx) } // Truncate implements fs.InodeOperations.Truncate. -func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { +func (*replicaInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { return nil } @@ -83,14 +85,15 @@ func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { // // This may race with destruction of the terminal. If the terminal is gone, it // returns ENOENT. -func (si *slaveInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, d, flags, &slaveFileOperations{si: si}), nil +func (si *replicaInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &replicaFileOperations{si: si}), nil } -// slaveFileOperations are the fs.FileOperations for the slave end of a terminal. +// replicaFileOperations are the fs.FileOperations for the replica end of a +// terminal. // // +stateify savable -type slaveFileOperations struct { +type replicaFileOperations struct { fsutil.FilePipeSeek `state:"nosave"` fsutil.FileNotDirReaddir `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` @@ -100,79 +103,84 @@ type slaveFileOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` // si is the inode operations. - si *slaveInodeOperations + si *replicaInodeOperations } -var _ fs.FileOperations = (*slaveFileOperations)(nil) +var _ fs.FileOperations = (*replicaFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (sf *slaveFileOperations) Release() { +func (sf *replicaFileOperations) Release(context.Context) { } // EventRegister implements waiter.Waitable.EventRegister. -func (sf *slaveFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - sf.si.t.ld.slaveWaiter.EventRegister(e, mask) +func (sf *replicaFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + sf.si.t.ld.replicaWaiter.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (sf *slaveFileOperations) EventUnregister(e *waiter.Entry) { - sf.si.t.ld.slaveWaiter.EventUnregister(e) +func (sf *replicaFileOperations) EventUnregister(e *waiter.Entry) { + sf.si.t.ld.replicaWaiter.EventUnregister(e) } // Readiness implements waiter.Waitable.Readiness. -func (sf *slaveFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return sf.si.t.ld.slaveReadiness() +func (sf *replicaFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return sf.si.t.ld.replicaReadiness() } // Read implements fs.FileOperations.Read. -func (sf *slaveFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { +func (sf *replicaFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { return sf.si.t.ld.inputQueueRead(ctx, dst) } // Write implements fs.FileOperations.Write. -func (sf *slaveFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { +func (sf *replicaFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { return sf.si.t.ld.outputQueueWrite(ctx, src) } // Ioctl implements fs.FileOperations.Ioctl. -func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (sf *replicaFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the input queue read buffer. - return 0, sf.si.t.ld.inputQueueReadSize(ctx, io, args) + return 0, sf.si.t.ld.inputQueueReadSize(t, args) case linux.TCGETS: - return sf.si.t.ld.getTermios(ctx, io, args) + return sf.si.t.ld.getTermios(t, args) case linux.TCSETS: - return sf.si.t.ld.setTermios(ctx, io, args) + return sf.si.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return sf.si.t.ld.setTermios(ctx, io, args) + return sf.si.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sf.si.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(sf.si.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCGWINSZ: - return 0, sf.si.t.ld.windowSize(ctx, io, args) + return 0, sf.si.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, sf.si.t.ld.setWindowSize(ctx, io, args) + return 0, sf.si.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sf.si.t.setControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sf.si.t.releaseControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + return sf.si.t.foregroundProcessGroup(ctx, args, false /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) + return sf.si.t.setForegroundProcessGroup(ctx, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY } } -// LINT.ThenChange(../../fsimpl/devpts/slave.go) +// LINT.ThenChange(../../fsimpl/devpts/replica.go) diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index ddcccf4da..4f431d74d 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,10 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -44,19 +44,19 @@ type Terminal struct { // this terminal. This field is immutable. masterKTTY *kernel.TTY - // slaveKTTY contains the controlling process of the slave end of this + // replicaKTTY contains the controlling process of the replica end of this // terminal. This field is immutable. - slaveKTTY *kernel.TTY + replicaKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { - termios := linux.DefaultSlaveTermios + termios := linux.DefaultReplicaTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{Index: n}, - slaveKTTY: &kernel.TTY{Index: n}, + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + replicaKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t @@ -64,7 +64,7 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal // setControllingTTY makes tm the controlling terminal of the calling thread // group. -func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("setControllingTTY must be called from a task context") @@ -75,7 +75,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a // releaseControllingTTY removes tm as the controlling terminal of the calling // thread group. -func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("releaseControllingTTY must be called from a task context") @@ -85,7 +85,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar } // foregroundProcessGroup gets the process group ID of tm's foreground process. -func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("foregroundProcessGroup must be called from a task context") @@ -97,24 +97,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a } // Write it out to *arg. - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ - AddressSpaceActive: true, - }) + retP := primitive.Int32(ret) + _, err = retP.CopyOut(task, args[2].Pointer()) return 0, err } // foregroundProcessGroup sets tm's foreground process. -func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("setForegroundProcessGroup must be called from a task context") } // Read in the process group ID. - var pgid int32 - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgid primitive.Int32 + if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } @@ -126,7 +123,7 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { if isMaster { return tm.masterKTTY } - return tm.slaveKTTY + return tm.replicaKTTY } // LINT.ThenChange(../../fsimpl/devpts/terminal.go) diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go index 2cbc05678..49edee83d 100644 --- a/pkg/sentry/fs/tty/tty_test.go +++ b/pkg/sentry/fs/tty/tty_test.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) +func TestSimpleMasterToReplica(t *testing.T) { + ld := newLineDiscipline(linux.DefaultReplicaTermios) ctx := contexttest.Context(t) inBytes := []byte("hello, tty\n") src := usermem.BytesIOSequence(inBytes) diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go index 397e96045..2f5a43b84 100644 --- a/pkg/sentry/fs/user/path.go +++ b/pkg/sentry/fs/user/path.go @@ -82,7 +82,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s // Caller has no root. Don't bother traversing anything. return "", syserror.ENOENT } - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range paths { if !path.IsAbs(p) { // Relative paths aren't safe, no one should be using them. @@ -100,7 +100,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s if err != nil { return "", err } - defer d.DecRef() + defer d.DecRef(ctx) // Check that it is a regular file. if !fs.IsRegular(d.Inode.StableAttr) { @@ -121,7 +121,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) { root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range paths { if !path.IsAbs(p) { // Relative paths aren't safe, no one should be using them. @@ -148,7 +148,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam if err != nil { return "", err } - dentry.DecRef() + dentry.DecRef(ctx) return binPath, nil } diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go index f4d525523..936fd3932 100644 --- a/pkg/sentry/fs/user/user.go +++ b/pkg/sentry/fs/user/user.go @@ -62,7 +62,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K // doesn't exist we will return the default home directory. return defaultHome, nil } - defer dirent.DecRef() + defer dirent.DecRef(ctx) // Check read permissions on the file. if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Read: true}); err != nil { @@ -81,7 +81,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K if err != nil { return "", err } - defer f.DecRef() + defer f.DecRef(ctx) r := &fileReader{ Ctx: ctx, @@ -105,7 +105,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth. const defaultHome = "/" root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) creds := auth.CredentialsFromContext(ctx) @@ -123,7 +123,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth. if err != nil { return defaultHome, nil } - defer fd.DecRef() + defer fd.DecRef(ctx) r := &fileReaderVFS2{ ctx: ctx, diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go index 7d8e9ac7c..12b786224 100644 --- a/pkg/sentry/fs/user/user_test.go +++ b/pkg/sentry/fs/user/user_test.go @@ -39,7 +39,7 @@ func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode if err != nil { return err } - defer etc.DecRef() + defer etc.DecRef(ctx) switch mode.FileType() { case 0: // Don't create anything. @@ -49,7 +49,7 @@ func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode if err != nil { return err } - defer passwd.DecRef() + defer passwd.DecRef(ctx) if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil { return err } @@ -110,9 +110,9 @@ func TestGetExecUserHome(t *testing.T) { if err != nil { t.Fatalf("NewMountNamespace failed: %v", err) } - defer mns.DecRef() + defer mns.DecRef(ctx) root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) ctx = fs.WithRoot(ctx, root) if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil { diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go index 8e7590721..7e61209ee 100644 --- a/pkg/sentry/fsbridge/bridge.go +++ b/pkg/sentry/fsbridge/bridge.go @@ -44,7 +44,7 @@ type File interface { IncRef() // DecRef decrements reference. - DecRef() + DecRef(ctx context.Context) } // Lookup provides a common interface to open files. diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go index 093ce1fb3..9785fd62a 100644 --- a/pkg/sentry/fsbridge/fs.go +++ b/pkg/sentry/fsbridge/fs.go @@ -49,7 +49,7 @@ func (f *fsFile) PathnameWithDeleted(ctx context.Context) string { // global there. return "" } - defer root.DecRef() + defer root.DecRef(ctx) name, _ := f.file.Dirent.FullName(root) return name @@ -87,8 +87,8 @@ func (f *fsFile) IncRef() { } // DecRef implements File. -func (f *fsFile) DecRef() { - f.file.DecRef() +func (f *fsFile) DecRef(ctx context.Context) { + f.file.DecRef(ctx) } // fsLookup implements Lookup interface using fs.File. @@ -124,7 +124,7 @@ func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptio if err != nil { return nil, err } - defer d.DecRef() + defer d.DecRef(ctx) if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) { return nil, syserror.ELOOP diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 89168220a..be0900030 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -43,7 +43,7 @@ func NewVFSFile(file *vfs.FileDescription) File { // PathnameWithDeleted implements File. func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string { root := vfs.RootFromContext(ctx) - defer root.DecRef() + defer root.DecRef(ctx) vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem() name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry()) @@ -86,8 +86,8 @@ func (f *VFSFile) IncRef() { } // DecRef implements File. -func (f *VFSFile) DecRef() { - f.file.DecRef() +func (f *VFSFile) DecRef(ctx context.Context) { + f.file.DecRef(ctx) } // FileDescription returns the FileDescription represented by f. It does not @@ -122,7 +122,7 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) // remainingTraversals is not configurable in VFS2, all callers are using the // default anyways. func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { - vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem() + vfsObj := l.root.Mount().Filesystem().VirtualFilesystem() creds := auth.CredentialsFromContext(ctx) path := fspath.Parse(pathname) pop := &vfs.PathOperation{ diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 93512c9b6..48e13613a 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -1,7 +1,19 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "root_inode_refs", + out = "root_inode_refs.go", + package = "devpts", + prefix = "rootInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "rootInode", + }, +) + go_library( name = "devpts", srcs = [ @@ -9,13 +21,18 @@ go_library( "line_discipline.go", "master.go", "queue.go", - "slave.go", + "replica.go", + "root_inode_refs.go", "terminal.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index e6fda2b4f..903135fae 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -35,6 +35,8 @@ import ( const Name = "devpts" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // Name implements vfs.FilesystemType.Name. @@ -58,6 +60,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -79,10 +82,11 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds // Construct the root directory. This is always inode id 1. root := &rootInode{ - slaves: make(map[uint32]*slaveInode), + replicas: make(map[uint32]*replicaInode), } root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + root.EnableLeakCheck() root.dentry.Init(root) // Construct the pts master inode and dentry. Linux always uses inode @@ -103,18 +107,22 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // rootInode is the root directory inode for the devpts mounts. +// +// +stateify savable type rootInode struct { + implStatFS kernfs.AlwaysValid kernfs.InodeAttrs kernfs.InodeDirectoryNoNewChildren kernfs.InodeNotSymlink kernfs.OrderedChildren + rootInodeRefs locks vfs.FileLocks @@ -128,10 +136,10 @@ type rootInode struct { root *rootInode // mu protects the fields below. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` - // slaves maps pty ids to slave inodes. - slaves map[uint32]*slaveInode + // replicas maps pty ids to replica inodes. + replicas map[uint32]*replicaInode // nextIdx is the next pty index to use. Must be accessed atomically. // @@ -151,22 +159,22 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) idx := i.nextIdx i.nextIdx++ - // Sanity check that slave with idx does not exist. - if _, ok := i.slaves[idx]; ok { + // Sanity check that replica with idx does not exist. + if _, ok := i.replicas[idx]; ok { panic(fmt.Sprintf("pty index collision; index %d already exists", idx)) } - // Create the new terminal and slave. + // Create the new terminal and replica. t := newTerminal(idx) - slave := &slaveInode{ + replica := &replicaInode{ root: i, t: t, } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) - slave.dentry.Init(slave) - i.slaves[idx] = slave + replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.dentry.Init(replica) + i.replicas[idx] = replica return t, nil } @@ -176,16 +184,18 @@ func (i *rootInode) masterClose(t *Terminal) { i.mu.Lock() defer i.mu.Unlock() - // Sanity check that slave with idx exists. - if _, ok := i.slaves[t.n]; !ok { + // Sanity check that replica with idx exists. + if _, ok := i.replicas[t.n]; !ok { panic(fmt.Sprintf("pty with index %d does not exist", t.n)) } - delete(i.slaves, t.n) + delete(i.replicas, t.n) } // Open implements kernfs.Inode.Open. -func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) +func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } @@ -193,16 +203,16 @@ func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D } // Lookup implements kernfs.Inode.Lookup. -func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (i *rootInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { idx, err := strconv.ParseUint(name, 10, 32) if err != nil { return nil, syserror.ENOENT } i.mu.Lock() defer i.mu.Unlock() - if si, ok := i.slaves[uint32(idx)]; ok { + if si, ok := i.replicas[uint32(idx)]; ok { si.dentry.IncRef() - return si.dentry.VFSDentry(), nil + return &si.dentry, nil } return nil, syserror.ENOENT @@ -212,8 +222,8 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() - ids := make([]int, 0, len(i.slaves)) - for id := range i.slaves { + ids := make([]int, 0, len(i.replicas)) + for id := range i.replicas { ids = append(ids, int(id)) } sort.Ints(ids) @@ -221,7 +231,7 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, dirent := vfs.Dirent{ Name: strconv.FormatUint(uint64(id), 10), Type: linux.DT_CHR, - Ino: i.slaves[uint32(id)].InodeAttrs.Ino(), + Ino: i.replicas[uint32(id)].InodeAttrs.Ino(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { @@ -231,3 +241,16 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, } return offset, nil } + +// DecRef implements kernfs.Inode.DecRef. +func (i *rootInode) DecRef(context.Context) { + i.rootInodeRefs.DecRef(i.Destroy) +} + +// +stateify savable +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.DEVPTS_SUPER_MAGIC), nil +} diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go index b7c149047..448390cfe 100644 --- a/pkg/sentry/fsimpl/devpts/devpts_test.go +++ b/pkg/sentry/fsimpl/devpts/devpts_test.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) +func TestSimpleMasterToReplica(t *testing.T) { + ld := newLineDiscipline(linux.DefaultReplicaTermios) ctx := contexttest.Context(t) inBytes := []byte("hello, tty\n") src := usermem.BytesIOSequence(inBytes) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index f7bc325d1..e6b0e81cf 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -41,7 +42,7 @@ const ( ) // lineDiscipline dictates how input and output are handled between the -// pseudoterminal (pty) master and slave. It can be configured to alter I/O, +// pseudoterminal (pty) master and replica. It can be configured to alter I/O, // modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man // pages are good resources for how to affect the line discipline: // @@ -52,8 +53,8 @@ const ( // // lineDiscipline has a simple structure but supports a multitude of options // (see the above man pages). It consists of two queues of bytes: one from the -// terminal master to slave (the input queue) and one from slave to master (the -// output queue). When bytes are written to one end of the pty, the line +// terminal master to replica (the input queue) and one from replica to master +// (the output queue). When bytes are written to one end of the pty, the line // discipline reads the bytes, modifies them or takes special action if // required, and enqueues them to be read by the other end of the pty: // @@ -62,7 +63,7 @@ const ( // | (inputQueueWrite) +-------------+ (inputQueueRead) | // | | // | v -// masterFD slaveFD +// masterFD replicaFD // ^ | // | | // | output to terminal +--------------+ output from process | @@ -101,8 +102,8 @@ type lineDiscipline struct { // masterWaiter is used to wait on the master end of the TTY. masterWaiter waiter.Queue `state:"zerovalue"` - // slaveWaiter is used to wait on the slave end of the TTY. - slaveWaiter waiter.Queue `state:"zerovalue"` + // replicaWaiter is used to wait on the replica end of the TTY. + replicaWaiter waiter.Queue `state:"zerovalue"` } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { @@ -113,27 +114,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { } // getTermios gets the linux.Termios for the tty. -func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.RLock() defer l.termiosMu.RUnlock() // We must copy a Termios struct, not KernelTermios. t := l.termios.ToTermios() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyOut(task, args[2].Pointer()) return 0, err } // setTermios sets a linux.Termios for the tty. -func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.Lock() defer l.termiosMu.Unlock() oldCanonEnabled := l.termios.LEnabled(linux.ICANON) // We must copy a Termios struct, not KernelTermios. var t linux.Termios - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyIn(task, args[2].Pointer()) l.termios.FromTermios(t) // If canonical mode is turned off, move bytes from inQueue's wait @@ -144,27 +141,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc l.inQueue.pushWaitBufLocked(l) l.inQueue.readable = true l.inQueue.mu.Unlock() - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return 0, err } -func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyOut(t, args[2].Pointer()) return err } -func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyIn(t, args[2].Pointer()) return err } @@ -174,14 +167,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask { return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios) } -func (l *lineDiscipline) slaveReadiness() waiter.EventMask { +func (l *lineDiscipline) replicaReadiness() waiter.EventMask { l.termiosMu.RLock() defer l.termiosMu.RUnlock() return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios) } -func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.inQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { + return l.inQueue.readableSize(t, io, args) } func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -194,7 +187,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque if n > 0 { l.masterWaiter.Notify(waiter.EventOut) if pushed { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return n, nil } @@ -209,14 +202,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) return n, nil } return 0, syserror.ErrWouldBlock } -func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.outQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { + return l.outQueue.readableSize(t, io, args) } func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -227,7 +220,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventOut) + l.replicaWaiter.Notify(waiter.EventOut) if pushed { l.masterWaiter.Notify(waiter.EventIn) } diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 69879498a..69c2fe951 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -17,9 +17,11 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -29,7 +31,10 @@ import ( ) // masterInode is the inode for the master end of the Terminal. +// +// +stateify savable type masterInode struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeNotDirectory @@ -47,28 +52,26 @@ type masterInode struct { var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. -func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { t, err := mi.root.allocateTerminal(rp.Credentials()) if err != nil { return nil, err } - mi.IncRef() fd := &masterFileDescription{ inode: mi, t: t, } fd.LockFD.Init(&mi.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { - mi.DecRef() + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil } // Stat implements kernfs.Inode.Stat. -func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := mi.InodeAttrs.Stat(vfsfs, opts) +func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -86,6 +89,7 @@ func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) } +// +stateify savable type masterFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -98,9 +102,8 @@ type masterFileDescription struct { var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil) // Release implements vfs.FileDescriptionImpl.Release. -func (mfd *masterFileDescription) Release() { +func (mfd *masterFileDescription) Release(ctx context.Context) { mfd.inode.root.masterClose(mfd.t) - mfd.inode.DecRef() } // EventRegister implements waiter.Waitable.EventRegister. @@ -130,46 +133,51 @@ func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSeque // Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the output queue read buffer. - return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args) + return 0, mfd.t.ld.outputQueueReadSize(t, io, args) case linux.TCGETS: // N.B. TCGETS on the master actually returns the configuration - // of the slave end. - return mfd.t.ld.getTermios(ctx, io, args) + // of the replica end. + return mfd.t.ld.getTermios(t, args) case linux.TCSETS: // N.B. TCSETS on the master actually affects the configuration - // of the slave end. - return mfd.t.ld.setTermios(ctx, io, args) + // of the replica end. + return mfd.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return mfd.t.ld.setTermios(ctx, io, args) + return mfd.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(mfd.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCSPTLCK: // TODO(b/29356795): Implement pty locking. For now just pretend we do. return 0, nil case linux.TIOCGWINSZ: - return 0, mfd.t.ld.windowSize(ctx, io, args) + return 0, mfd.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, mfd.t.ld.setWindowSize(ctx, io, args) + return 0, mfd.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mfd.t.setControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mfd.t.releaseControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + return mfd.t.foregroundProcessGroup(ctx, args, true /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) + return mfd.t.setForegroundProcessGroup(ctx, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY @@ -186,7 +194,7 @@ func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatO // Stat implements vfs.FileDescriptionImpl.Stat. func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem() - return mfd.inode.Stat(fs, opts) + return mfd.inode.Stat(ctx, fs, opts) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go index dffb4232c..55bff3e60 100644 --- a/pkg/sentry/fsimpl/devpts/queue.go +++ b/pkg/sentry/fsimpl/devpts/queue.go @@ -17,8 +17,10 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -30,7 +32,7 @@ import ( const waitBufMaxBytes = 131072 // queue represents one of the input or output queues between a pty master and -// slave. Bytes written to a queue are added to the read buffer until it is +// replica. Bytes written to a queue are added to the read buffer until it is // full, at which point they are written to the wait buffer. Bytes are // processed (i.e. undergo termios transformations) as they are added to the // read buffer. The read buffer is readable when its length is nonzero and @@ -83,17 +85,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask { } // readableSize writes the number of readable bytes to userspace. -func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (q *queue) readableSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { q.mu.Lock() defer q.mu.Unlock() - var size int32 + size := primitive.Int32(0) if q.readable { - size = int32(len(q.readBuf)) + size = primitive.Int32(len(q.readBuf)) } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := size.CopyOut(t, args[2].Pointer()) return err } @@ -102,8 +102,7 @@ func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.Sysca // as whether the read caused more readable data to become available (whether // data was pushed from the wait buffer to the read buffer). // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) { q.mu.Lock() defer q.mu.Unlock() @@ -143,8 +142,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl // write writes to q from userspace. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) { q.mu.Lock() defer q.mu.Unlock() @@ -186,8 +184,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip // writeBytes writes to q from b. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) writeBytes(b []byte, l *lineDiscipline) { q.mu.Lock() defer q.mu.Unlock() diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go new file mode 100644 index 000000000..6515c5536 --- /dev/null +++ b/pkg/sentry/fsimpl/devpts/replica.go @@ -0,0 +1,204 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devpts + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// replicaInode is the inode for the replica end of the Terminal. +// +// +stateify savable +type replicaInode struct { + implStatFS + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + locks vfs.FileLocks + + // Keep a reference to this inode's dentry. + dentry kernfs.Dentry + + // root is the devpts root inode. + root *rootInode + + // t is the connected Terminal. + t *Terminal +} + +var _ kernfs.Inode = (*replicaInode)(nil) + +// Open implements kernfs.Inode.Open. +func (ri *replicaInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &replicaFileDescription{ + inode: ri, + } + fd.LockFD.Init(&ri.locks) + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return &fd.vfsfd, nil + +} + +// Valid implements kernfs.Inode.Valid. +func (ri *replicaInode) Valid(context.Context) bool { + // Return valid if the replica still exists. + ri.root.mu.Lock() + defer ri.root.mu.Unlock() + _, ok := ri.root.replicas[ri.t.n] + return ok +} + +// Stat implements kernfs.Inode.Stat. +func (ri *replicaInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := ri.InodeAttrs.Stat(ctx, vfsfs, opts) + if err != nil { + return linux.Statx{}, err + } + statx.Blksize = 1024 + statx.RdevMajor = linux.UNIX98_PTY_REPLICA_MAJOR + statx.RdevMinor = ri.t.n + return statx, nil +} + +// SetStat implements kernfs.Inode.SetStat +func (ri *replicaInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + if opts.Stat.Mask&linux.STATX_SIZE != 0 { + return syserror.EINVAL + } + return ri.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) +} + +// +stateify savable +type replicaFileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *replicaInode +} + +var _ vfs.FileDescriptionImpl = (*replicaFileDescription)(nil) + +// Release implements fs.FileOperations.Release. +func (rfd *replicaFileDescription) Release(ctx context.Context) {} + +// EventRegister implements waiter.Waitable.EventRegister. +func (rfd *replicaFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + rfd.inode.t.ld.replicaWaiter.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (rfd *replicaFileDescription) EventUnregister(e *waiter.Entry) { + rfd.inode.t.ld.replicaWaiter.EventUnregister(e) +} + +// Readiness implements waiter.Waitable.Readiness. +func (rfd *replicaFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { + return rfd.inode.t.ld.replicaReadiness() +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (rfd *replicaFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { + return rfd.inode.t.ld.inputQueueRead(ctx, dst) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (rfd *replicaFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { + return rfd.inode.t.ld.outputQueueWrite(ctx, src) +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (rfd *replicaFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + + switch cmd := args[1].Uint(); cmd { + case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ + // Get the number of bytes in the input queue read buffer. + return 0, rfd.inode.t.ld.inputQueueReadSize(t, io, args) + case linux.TCGETS: + return rfd.inode.t.ld.getTermios(t, args) + case linux.TCSETS: + return rfd.inode.t.ld.setTermios(t, args) + case linux.TCSETSW: + // TODO(b/29356795): This should drain the output queue first. + return rfd.inode.t.ld.setTermios(t, args) + case linux.TIOCGPTN: + nP := primitive.Uint32(rfd.inode.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) + return 0, err + case linux.TIOCGWINSZ: + return 0, rfd.inode.t.ld.windowSize(t, args) + case linux.TIOCSWINSZ: + return 0, rfd.inode.t.ld.setWindowSize(t, args) + case linux.TIOCSCTTY: + // Make the given terminal the controlling terminal of the + // calling process. + return 0, rfd.inode.t.setControllingTTY(ctx, args, false /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, rfd.inode.t.releaseControllingTTY(ctx, args, false /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return rfd.inode.t.foregroundProcessGroup(ctx, args, false /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return rfd.inode.t.setForegroundProcessGroup(ctx, args, false /* isMaster */) + default: + maybeEmitUnimplementedEvent(ctx, cmd) + return 0, syserror.ENOTTY + } +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (rfd *replicaFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + creds := auth.CredentialsFromContext(ctx) + fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem() + return rfd.inode.SetStat(ctx, fs, creds, opts) +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (rfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem() + return rfd.inode.Stat(ctx, fs, opts) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (rfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return rfd.Locks().LockPOSIX(ctx, &rfd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (rfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return rfd.Locks().UnlockPOSIX(ctx, &rfd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go deleted file mode 100644 index cf1a0f0ac..000000000 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devpts - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -// slaveInode is the inode for the slave end of the Terminal. -type slaveInode struct { - kernfs.InodeAttrs - kernfs.InodeNoopRefCount - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - - locks vfs.FileLocks - - // Keep a reference to this inode's dentry. - dentry kernfs.Dentry - - // root is the devpts root inode. - root *rootInode - - // t is the connected Terminal. - t *Terminal -} - -var _ kernfs.Inode = (*slaveInode)(nil) - -// Open implements kernfs.Inode.Open. -func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - si.IncRef() - fd := &slaveFileDescription{ - inode: si, - } - fd.LockFD.Init(&si.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { - si.DecRef() - return nil, err - } - return &fd.vfsfd, nil - -} - -// Valid implements kernfs.Inode.Valid. -func (si *slaveInode) Valid(context.Context) bool { - // Return valid if the slave still exists. - si.root.mu.Lock() - defer si.root.mu.Unlock() - _, ok := si.root.slaves[si.t.n] - return ok -} - -// Stat implements kernfs.Inode.Stat. -func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := si.InodeAttrs.Stat(vfsfs, opts) - if err != nil { - return linux.Statx{}, err - } - statx.Blksize = 1024 - statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR - statx.RdevMinor = si.t.n - return statx, nil -} - -// SetStat implements kernfs.Inode.SetStat -func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - if opts.Stat.Mask&linux.STATX_SIZE != 0 { - return syserror.EINVAL - } - return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) -} - -type slaveFileDescription struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.LockFD - - inode *slaveInode -} - -var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil) - -// Release implements fs.FileOperations.Release. -func (sfd *slaveFileDescription) Release() { - sfd.inode.DecRef() -} - -// EventRegister implements waiter.Waitable.EventRegister. -func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask) -} - -// EventUnregister implements waiter.Waitable.EventUnregister. -func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) { - sfd.inode.t.ld.slaveWaiter.EventUnregister(e) -} - -// Readiness implements waiter.Waitable.Readiness. -func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { - return sfd.inode.t.ld.slaveReadiness() -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { - return sfd.inode.t.ld.inputQueueRead(ctx, dst) -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { - return sfd.inode.t.ld.outputQueueWrite(ctx, src) -} - -// Ioctl implements vfs.FileDescripionImpl.Ioctl. -func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - switch cmd := args[1].Uint(); cmd { - case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ - // Get the number of bytes in the input queue read buffer. - return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args) - case linux.TCGETS: - return sfd.inode.t.ld.getTermios(ctx, io, args) - case linux.TCSETS: - return sfd.inode.t.ld.setTermios(ctx, io, args) - case linux.TCSETSW: - // TODO(b/29356795): This should drain the output queue first. - return sfd.inode.t.ld.setTermios(ctx, io, args) - case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err - case linux.TIOCGWINSZ: - return 0, sfd.inode.t.ld.windowSize(ctx, io, args) - case linux.TIOCSWINSZ: - return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args) - case linux.TIOCSCTTY: - // Make the given terminal the controlling terminal of the - // calling process. - return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */) - case linux.TIOCNOTTY: - // Release this process's controlling terminal. - return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) - case linux.TIOCGPGRP: - // Get the foreground process group. - return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) - case linux.TIOCSPGRP: - // Set the foreground process group. - return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) - default: - maybeEmitUnimplementedEvent(ctx, cmd) - return 0, syserror.ENOTTY - } -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - creds := auth.CredentialsFromContext(ctx) - fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.SetStat(ctx, fs, creds, opts) -} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.Stat(fs, opts) -} - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go index 7d2781c54..510bd6d89 100644 --- a/pkg/sentry/fsimpl/devpts/terminal.go +++ b/pkg/sentry/fsimpl/devpts/terminal.go @@ -17,9 +17,9 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" ) // Terminal is a pseudoterminal. @@ -36,25 +36,25 @@ type Terminal struct { // this terminal. This field is immutable. masterKTTY *kernel.TTY - // slaveKTTY contains the controlling process of the slave end of this + // replicaKTTY contains the controlling process of the replica end of this // terminal. This field is immutable. - slaveKTTY *kernel.TTY + replicaKTTY *kernel.TTY } func newTerminal(n uint32) *Terminal { - termios := linux.DefaultSlaveTermios + termios := linux.DefaultReplicaTermios t := Terminal{ - n: n, - ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{Index: n}, - slaveKTTY: &kernel.TTY{Index: n}, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + replicaKTTY: &kernel.TTY{Index: n}, } return &t } // setControllingTTY makes tm the controlling terminal of the calling thread // group. -func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("setControllingTTY must be called from a task context") @@ -65,7 +65,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a // releaseControllingTTY removes tm as the controlling terminal of the calling // thread group. -func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("releaseControllingTTY must be called from a task context") @@ -75,7 +75,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar } // foregroundProcessGroup gets the process group ID of tm's foreground process. -func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("foregroundProcessGroup must be called from a task context") @@ -87,24 +87,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a } // Write it out to *arg. - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ - AddressSpaceActive: true, - }) + retP := primitive.Int32(ret) + _, err = retP.CopyOut(task, args[2].Pointer()) return 0, err } // foregroundProcessGroup sets tm's foreground process. -func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("setForegroundProcessGroup must be called from a task context") } // Read in the process group ID. - var pgid int32 - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgid primitive.Int32 + if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } @@ -116,5 +113,5 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { if isMaster { return tm.masterKTTY } - return tm.slaveKTTY + return tm.replicaKTTY } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index aa0c2ad8c..01bbee5ad 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -24,6 +24,7 @@ go_test( library = ":devtmpfs", deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/tmpfs", diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index d0e06cdc0..6d1753080 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -18,6 +18,7 @@ package devtmpfs import ( "fmt" + "path" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -32,8 +33,10 @@ import ( const Name = "devtmpfs" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct { - initOnce sync.Once + initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1664): not yet supported. initErr error // fs is the tmpfs filesystem that backs all mounts of this FilesystemType. @@ -79,7 +82,7 @@ type Accessor struct { // NewAccessor returns an Accessor that supports creation of device special // files in the devtmpfs instance registered with name fsTypeName in vfsObj. func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) { - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.MountOptions{}) if err != nil { return nil, err } @@ -92,9 +95,9 @@ func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth } // Release must be called when a is no longer in use. -func (a *Accessor) Release() { - a.root.DecRef() - a.mntns.DecRef() +func (a *Accessor) Release(ctx context.Context) { + a.root.DecRef(ctx) + a.mntns.DecRef(ctx) } // accessorContext implements context.Context by extending an existing @@ -150,13 +153,11 @@ func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind v // Create any parent directories. See // devtmpfs.c:handle_create()=>path_create(). - for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() { - pop := a.pathOperationAt(it.String()) - if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - return fmt.Errorf("failed to create directory %q: %v", it.String(), err) - } + parent := path.Dir(pathname) + if err := a.vfsObj.MkdirAllAt(ctx, parent, a.root, a.creds, &vfs.MkdirOptions{ + Mode: 0755, + }); err != nil { + return fmt.Errorf("failed to create device parent directory %q: %v", parent, err) } // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go index b6d52c015..3a38b8bb4 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go @@ -15,9 +15,11 @@ package devtmpfs import ( + "path" "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" @@ -25,12 +27,15 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" ) -func TestDevtmpfs(t *testing.T) { +const devPath = "/dev" + +func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry, func()) { + t.Helper() + ctx := contexttest.Context(t) creds := auth.CredentialsFromContext(ctx) - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } // Register tmpfs just so that we can have a root filesystem that isn't @@ -43,14 +48,11 @@ func TestDevtmpfs(t *testing.T) { }) // Create a test mount namespace with devtmpfs mounted at "/dev". - const devPath = "/dev" - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.MountOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef() root := mntns.Root() - defer root.DecRef() devpop := vfs.PathOperation{ Root: root, Start: root, @@ -61,62 +63,167 @@ func TestDevtmpfs(t *testing.T) { }); err != nil { t.Fatalf("failed to create mount point: %v", err) } - if err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil { + if _, err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil { t.Fatalf("failed to mount devtmpfs: %v", err) } + return ctx, creds, vfsObj, root, func() { + root.DecRef(ctx) + mntns.DecRef(ctx) + } +} + +func TestUserspaceInit(t *testing.T) { + ctx, creds, vfsObj, root, cleanup := setupDevtmpfs(t) + defer cleanup() + a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") if err != nil { t.Fatalf("failed to create devtmpfs.Accessor: %v", err) } - defer a.Release() + defer a.Release(ctx) // Create "userspace-initialized" files using a devtmpfs.Accessor. if err := a.UserspaceInit(ctx); err != nil { t.Fatalf("failed to userspace-initialize devtmpfs: %v", err) } + // Created files should be visible in the test mount namespace. - abspath := devPath + "/fd" - target, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }) - if want := "/proc/self/fd"; err != nil || target != want { - t.Fatalf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, target, err, want) + links := []struct { + source string + target string + }{ + { + source: "fd", + target: "/proc/self/fd", + }, + { + source: "stdin", + target: "/proc/self/fd/0", + }, + { + source: "stdout", + target: "/proc/self/fd/1", + }, + { + source: "stderr", + target: "/proc/self/fd/2", + }, + { + source: "ptmx", + target: "pts/ptmx", + }, } - // Create a dummy device special file using a devtmpfs.Accessor. - const ( - pathInDev = "dummy" - kind = vfs.CharDevice - major = 12 - minor = 34 - perms = 0600 - wantMode = linux.S_IFCHR | perms - ) - if err := a.CreateDeviceFile(ctx, pathInDev, kind, major, minor, perms); err != nil { - t.Fatalf("failed to create device file: %v", err) + for _, link := range links { + abspath := path.Join(devPath, link.source) + if gotTarget, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(abspath), + }); err != nil || gotTarget != link.target { + t.Errorf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, gotTarget, err, link.target) + } } - // The device special file should be visible in the test mount namespace. - abspath = devPath + "/" + pathInDev - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }, &vfs.StatOptions{ - Mask: linux.STATX_TYPE | linux.STATX_MODE, - }) - if err != nil { - t.Fatalf("failed to stat device file at %q: %v", abspath, err) + + dirs := []string{"shm", "pts"} + for _, dir := range dirs { + abspath := path.Join(devPath, dir) + statx, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(abspath), + }, &vfs.StatOptions{ + Mask: linux.STATX_MODE, + }) + if err != nil { + t.Errorf("stat(%q): got error %v ", abspath, err) + continue + } + if want := uint16(0755) | linux.S_IFDIR; statx.Mode != want { + t.Errorf("stat(%q): got mode %x, want %x", abspath, statx.Mode, want) + } } - if stat.Mode != wantMode { - t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode) +} + +func TestCreateDeviceFile(t *testing.T) { + ctx, creds, vfsObj, root, cleanup := setupDevtmpfs(t) + defer cleanup() + + a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") + if err != nil { + t.Fatalf("failed to create devtmpfs.Accessor: %v", err) } - if stat.RdevMajor != major { - t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, major) + defer a.Release(ctx) + + devFiles := []struct { + path string + kind vfs.DeviceKind + major uint32 + minor uint32 + perms uint16 + }{ + { + path: "dummy", + kind: vfs.CharDevice, + major: 12, + minor: 34, + perms: 0600, + }, + { + path: "foo/bar", + kind: vfs.BlockDevice, + major: 13, + minor: 35, + perms: 0660, + }, + { + path: "foo/baz", + kind: vfs.CharDevice, + major: 12, + minor: 40, + perms: 0666, + }, + { + path: "a/b/c/d/e", + kind: vfs.BlockDevice, + major: 12, + minor: 34, + perms: 0600, + }, } - if stat.RdevMinor != minor { - t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, minor) + + for _, f := range devFiles { + if err := a.CreateDeviceFile(ctx, f.path, f.kind, f.major, f.minor, f.perms); err != nil { + t.Fatalf("failed to create device file: %v", err) + } + // The device special file should be visible in the test mount namespace. + abspath := path.Join(devPath, f.path) + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(abspath), + }, &vfs.StatOptions{ + Mask: linux.STATX_TYPE | linux.STATX_MODE, + }) + if err != nil { + t.Fatalf("failed to stat device file at %q: %v", abspath, err) + } + if stat.RdevMajor != f.major { + t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, f.major) + } + if stat.RdevMinor != f.minor { + t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, f.minor) + } + wantMode := f.perms + switch f.kind { + case vfs.CharDevice: + wantMode |= linux.S_IFCHR + case vfs.BlockDevice: + wantMode |= linux.S_IFBLK + } + if stat.Mode != wantMode { + t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode) + } } } diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index d12d78b84..1c27ad700 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -30,9 +30,11 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// EventFileDescription implements FileDescriptionImpl for file-based event +// EventFileDescription implements vfs.FileDescriptionImpl for file-based event // notification (eventfd). Eventfds are usually internal to the Sentry but in // certain situations they may be converted into a host-backed eventfd. +// +// +stateify savable type EventFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -59,9 +61,9 @@ type EventFileDescription struct { var _ vfs.FileDescriptionImpl = (*EventFileDescription)(nil) // New creates a new event fd. -func New(vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) { +func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) { vd := vfsObj.NewAnonVirtualDentry("[eventfd]") - defer vd.DecRef() + defer vd.DecRef(ctx) efd := &EventFileDescription{ val: initVal, semMode: semMode, @@ -106,8 +108,8 @@ func (efd *EventFileDescription) HostFD() (int, error) { return efd.hostfd, nil } -// Release implements FileDescriptionImpl.Release() -func (efd *EventFileDescription) Release() { +// Release implements vfs.FileDescriptionImpl.Release. +func (efd *EventFileDescription) Release(context.Context) { efd.mu.Lock() defer efd.mu.Unlock() if efd.hostfd >= 0 { @@ -119,7 +121,7 @@ func (efd *EventFileDescription) Release() { } } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { if dst.NumBytes() < 8 { return 0, syscall.EINVAL @@ -130,7 +132,7 @@ func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequenc return 8, nil } -// Write implements FileDescriptionImpl.Write. +// Write implements vfs.FileDescriptionImpl.Write. func (efd *EventFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { if src.NumBytes() < 8 { return 0, syscall.EINVAL diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_test.go b/pkg/sentry/fsimpl/eventfd/eventfd_test.go index 20e3adffc..49916fa81 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd_test.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd_test.go @@ -36,16 +36,16 @@ func TestEventFD(t *testing.T) { for _, initVal := range initVals { ctx := contexttest.Context(t) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } // Make a new eventfd that is writable. - eventfd, err := New(vfsObj, initVal, false, linux.O_RDWR) + eventfd, err := New(ctx, vfsObj, initVal, false, linux.O_RDWR) if err != nil { t.Fatalf("New() failed: %v", err) } - defer eventfd.DecRef() + defer eventfd.DecRef(ctx) // Register a callback for a write event. w, ch := waiter.NewChannelEntry(nil) @@ -74,16 +74,16 @@ func TestEventFD(t *testing.T) { func TestEventFDStat(t *testing.T) { ctx := contexttest.Context(t) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } // Make a new eventfd that is writable. - eventfd, err := New(vfsObj, 0, false, linux.O_RDWR) + eventfd, err := New(ctx, vfsObj, 0, false, linux.O_RDWR) if err != nil { t.Fatalf("New() failed: %v", err) } - defer eventfd.DecRef() + defer eventfd.DecRef(ctx) statx, err := eventfd.Stat(ctx, vfs.StatOptions{ Mask: linux.STATX_BASIC_STATS, diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index ef24f8159..7b1eec3da 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/fd", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", @@ -86,9 +88,9 @@ go_test( library = ":ext", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fspath", + "//pkg/marshal/primitive", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", @@ -96,7 +98,7 @@ go_test( "//pkg/syserror", "//pkg/test/testutil", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 89caee3df..c349b886e 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -53,13 +53,17 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { return nil, nil, nil, nil, err } vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -68,7 +72,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys root := mntns.Root() tearDown := func() { - root.DecRef() + root.DecRef(ctx) if err := f.Close(); err != nil { b.Fatalf("tearDown failed: %v", err) @@ -90,7 +94,7 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf ctx := contexttest.Context(b) creds := auth.CredentialsFromContext(ctx) - if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ + if _, err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: int(f.Fd()), }, @@ -169,7 +173,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount point: %v", err) } - defer mountPoint.DecRef() + defer mountPoint.DecRef(ctx) // Create extfs submount. mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop) diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index 8bb104ff0..1165234f9 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -18,7 +18,7 @@ import ( "io" "math" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/syserror" ) @@ -34,19 +34,19 @@ type blockMapFile struct { // directBlks are the direct blocks numbers. The physical blocks pointed by // these holds file data. Contains file blocks 0 to 11. - directBlks [numDirectBlks]uint32 + directBlks [numDirectBlks]primitive.Uint32 // indirectBlk is the physical block which contains (blkSize/4) direct block // numbers (as uint32 integers). - indirectBlk uint32 + indirectBlk primitive.Uint32 // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect // block numbers (as uint32 integers). - doubleIndirectBlk uint32 + doubleIndirectBlk primitive.Uint32 // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly // indirect block numbers (as uint32 integers). - tripleIndirectBlk uint32 + tripleIndirectBlk primitive.Uint32 // coverage at (i)th index indicates the amount of file data a node at // height (i) covers. Height 0 is the direct block. @@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) { } blkMap := file.regFile.inode.diskInode.Data() - binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks) - binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk) + for i := 0; i < numDirectBlks; i++ { + file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4]) + } + file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4]) + file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4]) + file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4]) return file, nil } @@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { switch { case offset < dirBlksEnd: // Direct block. - curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:]) + curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:]) case offset < indirBlkEnd: // Indirect block. - curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:]) + curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:]) case offset < doubIndirBlkEnd: // Doubly indirect block. - curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:]) + curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:]) default: // Triply indirect block. - curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:]) + curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:]) } read += curR @@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds read := 0 curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { - var childPhyBlk uint32 + var childPhyBlk primitive.Uint32 err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } - n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:]) + n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:]) read += n if err != nil { return read, err diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 6fa84e7aa..ed98b482e 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -20,7 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) var fileData []byte blkNums := newBlkNumGen() - var data []byte + off := 0 + data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes()) // Write the direct blocks. for i := 0; i < numDirectBlks; i++ { - curBlkNum := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, curBlkNum) - fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(data[off:]) + off += curBlkNum.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...) } // Write to indirect block. - indirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, indirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...) - - // Write to indirect block. - doublyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...) - - // Write to indirect block. - triplyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...) + indirectBlk := primitive.Uint32(blkNums.next()) + indirectBlk.MarshalBytes(data[off:]) + off += indirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...) + + // Write to double indirect block. + doublyIndirectBlk := primitive.Uint32(blkNums.next()) + doublyIndirectBlk.MarshalBytes(data[off:]) + off += doublyIndirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...) + + // Write to triple indirect block. + triplyIndirectBlk := primitive.Uint32(blkNums.next()) + triplyIndirectBlk.MarshalBytes(data[off:]) + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...) args := inodeArgs{ fs: &filesystem{ @@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN var fileData []byte for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 { - curBlkNum := blkNums.next() - copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum)) - fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(disk[off : off+4]) + fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...) } return fileData } diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 55902322a..9bfed883a 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -15,10 +15,13 @@ package ext import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/vfs" ) // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -55,7 +58,7 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { +func (d *dentry) DecRef(ctx context.Context) { // FIXME(b/134676337): filesystem.mu may not be locked as required by // inode.decRef(). d.inode.decRef() @@ -64,7 +67,7 @@ func (d *dentry) DecRef() { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // TODO(b/134676337): Implement inotify. -func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. // @@ -76,4 +79,4 @@ func (d *dentry) Watches() *vfs.Watches { // OnZeroWatches implements vfs.Dentry.OnZeroWatches. // // TODO(b/134676337): Implement inotify. -func (d *dentry) OnZeroWatches() {} +func (d *dentry) OnZeroWatches(context.Context) {} diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 357512c7e..0ad79b381 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -16,7 +16,6 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -28,6 +27,8 @@ import ( ) // directory represents a directory inode. It holds the childList in memory. +// +// +stateify savable type directory struct { inode inode @@ -39,7 +40,7 @@ type directory struct { // Lock Order (outermost locks must be taken first): // directory.mu // filesystem.mu - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // childList is a list containing (1) child dirents and (2) fake dirents // (with diskDirent == nil) that represent the iteration position of @@ -98,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) { } else { curDirent.diskDirent = &disklayout.DirentOld{} } - binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent) + curDirent.diskDirent.UnmarshalBytes(buf) if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 { // Inode number and name length fields being set to 0 is used to indicate @@ -120,6 +121,8 @@ func (i *inode) isDir() bool { } // dirent is the directory.childList node. +// +// +stateify savable type dirent struct { diskDirent disklayout.Dirent @@ -129,6 +132,8 @@ type dirent struct { // directoryFD represents a directory file description. It implements // vfs.FileDescriptionImpl. +// +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl @@ -142,7 +147,7 @@ type directoryFD struct { var _ vfs.FileDescriptionImpl = (*directoryFD)(nil) // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(ctx context.Context) { if fd.iter == nil { return } diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 9bd9c76c0..d98a05dd8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -22,10 +22,11 @@ go_library( "superblock_old.go", "test_utils.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/marshal", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go index ad6f4fef8..0d56ae9da 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // BlockGroup represents a Linux ext block group descriptor. An ext file system // is split into a series of block groups. This provides an access layer to // information needed to access and use a block group. @@ -30,6 +34,8 @@ package disklayout // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors. type BlockGroup interface { + marshal.Marshallable + // InodeTable returns the absolute block number of the block containing the // inode table. This points to an array of Inode structs. Inode tables are // statically allocated at mkfs time. The superblock records the number of diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go index 3e16c76db..a35fa22a0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go @@ -17,6 +17,8 @@ package disklayout // BlockGroup32Bit emulates the first half of struct ext4_group_desc in // fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and // 32-bit ext4 filesystems. It implements BlockGroup interface. +// +// +marshal type BlockGroup32Bit struct { BlockBitmapLo uint32 InodeBitmapLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go index 9a809197a..d54d1d345 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go @@ -18,6 +18,8 @@ package disklayout // It is the block group descriptor struct for 64-bit ext4 filesystems. // It implements BlockGroup interface. It is an extension of the 32-bit // version of BlockGroup. +// +// +marshal type BlockGroup64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go index 0ef4294c0..e4ce484e4 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go @@ -21,6 +21,8 @@ import ( // TestBlockGroupSize tests that the block group descriptor structs are of the // correct size. func TestBlockGroupSize(t *testing.T) { - assertSize(t, BlockGroup32Bit{}, 32) - assertSize(t, BlockGroup64Bit{}, 64) + var bgSmall BlockGroup32Bit + assertSize(t, &bgSmall, 32) + var bgBig BlockGroup64Bit + assertSize(t, &bgBig, 64) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go index 417b6cf65..568c8cb4c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go @@ -15,6 +15,7 @@ package disklayout import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fs" ) @@ -51,6 +52,8 @@ var ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories. type Dirent interface { + marshal.Marshallable + // Inode returns the absolute inode number of the underlying inode. // Inode number 0 signifies an unused dirent. Inode() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go index 29ae4a5c2..51f9c2946 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go @@ -29,12 +29,14 @@ import ( // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentNew struct { InodeNumber uint32 RecordLength uint16 NameLength uint8 FileTypeRaw uint8 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentNew implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go index 6fff12a6e..d4b19e086 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go @@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs" // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentOld struct { InodeNumber uint32 RecordLength uint16 NameLength uint16 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentOld implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go index 934919f8a..3486864dc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go @@ -21,6 +21,8 @@ import ( // TestDirentSize tests that the dirent structs are of the correct // size. func TestDirentSize(t *testing.T) { - assertSize(t, DirentOld{}, uintptr(DirentSize)) - assertSize(t, DirentNew{}, uintptr(DirentSize)) + var dOld DirentOld + assertSize(t, &dOld, DirentSize) + var dNew DirentNew + assertSize(t, &dNew, DirentSize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go index bdf4e2132..0834e9ba8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go +++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go @@ -36,8 +36,6 @@ // escape analysis on an unknown implementation at compile time. // // Notes: -// - All fields in these structs are exported because binary.Read would -// panic otherwise. // - All structures on disk are in little-endian order. Only jbd2 (journal) // structures are in big-endian order. // - All OS dependent fields in these structures will be interpretted using diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go index 4110649ab..b13999bfc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // Extents were introduced in ext4 and provide huge performance gains in terms // data locality and reduced metadata block usage. Extents are organized in // extent trees. The root node is contained in inode.BlocksRaw. @@ -64,6 +68,8 @@ type ExtentNode struct { // ExtentEntry represents an extent tree node entry. The entry can either be // an ExtentIdx or Extent itself. This exists to simplify navigation logic. type ExtentEntry interface { + marshal.Marshallable + // FileBlock returns the first file block number covered by this entry. FileBlock() uint32 @@ -75,6 +81,8 @@ type ExtentEntry interface { // tree node begins with this and is followed by `NumEntries` number of: // - Extent if `Depth` == 0 // - ExtentIdx otherwise +// +// +marshal type ExtentHeader struct { // Magic in the extent magic number, must be 0xf30a. Magic uint16 @@ -96,6 +104,8 @@ type ExtentHeader struct { // internal nodes. Sorted in ascending order based on FirstFileBlock since // Linux does a binary search on this. This points to a block containing the // child node. +// +// +marshal type ExtentIdx struct { FirstFileBlock uint32 ChildBlockLo uint32 @@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 { // nodes. Sorted in ascending order based on FirstFileBlock since Linux does a // binary search on this. This points to an array of data blocks containing the // file data. It covers `Length` data blocks starting from `StartBlock`. +// +// +marshal type Extent struct { FirstFileBlock uint32 Length uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go index 8762b90db..c96002e19 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go @@ -21,7 +21,10 @@ import ( // TestExtentSize tests that the extent structs are of the correct // size. func TestExtentSize(t *testing.T) { - assertSize(t, ExtentHeader{}, ExtentHeaderSize) - assertSize(t, ExtentIdx{}, ExtentEntrySize) - assertSize(t, Extent{}, ExtentEntrySize) + var h ExtentHeader + assertSize(t, &h, ExtentHeaderSize) + var i ExtentIdx + assertSize(t, &i, ExtentEntrySize) + var e Extent + assertSize(t, &e, ExtentEntrySize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go index 88ae913f5..ef25040a9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go @@ -16,6 +16,7 @@ package disklayout import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) @@ -38,6 +39,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes. type Inode interface { + marshal.Marshallable + // Mode returns the linux file mode which is majorly used to extract // information like: // - File permissions (read/write/execute by user/group/others). diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go index 8f9f574ce..a4503f5cf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go @@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time" // are used to provide nanoscond precision. Hence, these timestamps will now // overflow in May 2446. // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps. +// +// +marshal type InodeNew struct { InodeOld diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go index db25b11b6..e6b28babf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go @@ -30,6 +30,8 @@ const ( // // All fields representing time are in seconds since the epoch. Which means that // they will overflow in January 2038. +// +// +marshal type InodeOld struct { ModeRaw uint16 UIDLo uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go index dd03ee50e..90744e956 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go @@ -24,10 +24,12 @@ import ( // TestInodeSize tests that the inode structs are of the correct size. func TestInodeSize(t *testing.T) { - assertSize(t, InodeOld{}, OldInodeSize) + var iOld InodeOld + assertSize(t, &iOld, OldInodeSize) // This was updated from 156 bytes to 160 bytes in Oct 2015. - assertSize(t, InodeNew{}, 160) + var iNew InodeNew + assertSize(t, &iNew, 160) } // TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go index 8bb327006..70948ebe9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + const ( // SbOffset is the absolute offset at which the superblock is placed. SbOffset = 1024 @@ -38,6 +42,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block. type SuperBlock interface { + marshal.Marshallable + // InodesCount returns the total number of inodes in this filesystem. InodesCount() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go index 53e515fd3..4dc6080fb 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go @@ -17,6 +17,8 @@ package disklayout // SuperBlock32Bit implements SuperBlock and represents the 32-bit version of // the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if // RevLevel = DynamicRev and 64-bit feature is disabled. +// +// +marshal type SuperBlock32Bit struct { // We embed the old superblock struct here because the 32-bit version is just // an extension of the old version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go index 7c1053fb4..2c9039327 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go @@ -19,6 +19,8 @@ package disklayout // 1024 bytes (smallest possible block size) and hence the superblock always // fits in no more than one data block. Should only be used when the 64-bit // feature is set. +// +// +marshal type SuperBlock64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go index 9221e0251..e4709f23c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go @@ -16,6 +16,8 @@ package disklayout // SuperBlockOld implements SuperBlock and represents the old version of the // superblock struct. Should be used only if RevLevel = OldRev. +// +// +marshal type SuperBlockOld struct { InodesCountRaw uint32 BlocksCountLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go index 463b5ba21..b734b6987 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go @@ -21,7 +21,10 @@ import ( // TestSuperBlockSize tests that the superblock structs are of the correct // size. func TestSuperBlockSize(t *testing.T) { - assertSize(t, SuperBlockOld{}, 84) - assertSize(t, SuperBlock32Bit{}, 336) - assertSize(t, SuperBlock64Bit{}, 1024) + var sbOld SuperBlockOld + assertSize(t, &sbOld, 84) + var sb32 SuperBlock32Bit + assertSize(t, &sb32, 336) + var sb64 SuperBlock64Bit + assertSize(t, &sb64, 1024) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go index 9c63f04c0..a4bc08411 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go +++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go @@ -18,13 +18,13 @@ import ( "reflect" "testing" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" ) -func assertSize(t *testing.T, v interface{}, want uintptr) { +func assertSize(t *testing.T, v marshal.Marshallable, want int) { t.Helper() - if got := binary.Size(v); got != want { + if got := v.SizeBytes(); got != want { t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got) } } diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index dac6effbf..aca258d40 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -34,6 +34,8 @@ import ( const Name = "ext" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // Compiles only if FilesystemType implements vfs.FilesystemType. @@ -123,32 +125,32 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs.vfsfs.Init(vfsObj, &fsType, &fs) fs.sb, err = readSuperBlock(dev) if err != nil { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } if fs.sb.Magic() != linux.EXT_SUPER_MAGIC { // mount(2) specifies that EINVAL should be returned if the superblock is // invalid. - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL } // Refuse to mount if the filesystem is incompatible. if !isCompatible(fs.sb) { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL } fs.bgs, err = readBlockGroups(dev, fs.sb) if err != nil { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode) if err != nil { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } rootInode.incRef() diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 64e9a579f..0989558cd 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -65,13 +65,17 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -80,7 +84,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys root := mntns.Root() tearDown := func() { - root.DecRef() + root.DecRef(ctx) if err := f.Close(); err != nil { t.Fatalf("tearDown failed: %v", err) diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index c36225a7c..778460107 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -18,12 +18,13 @@ import ( "io" "sort" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) // extentFile is a type of regular file which uses extents to store file data. +// +// +stateify savable type extentFile struct { regFile regularFile @@ -58,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) { func (f *extentFile) buildExtTree() error { rootNodeData := f.regFile.inode.diskInode.Data() - binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header) + f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize]) // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. if f.root.Header.NumEntries > 4 { @@ -77,7 +78,7 @@ func (f *extentFile) buildExtTree() error { // Internal node. curEntry = &disklayout.ExtentIdx{} } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry) + curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize]) f.root.Entries[i].Entry = curEntry } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index cd10d46ee..985f76ac0 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] // writeTree writes the tree represented by `root` to the inode and disk. It // also writes random file data on disk. func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte { - rootData := binary.Marshal(nil, binary.LittleEndian, root.Header) + rootData := in.diskInode.Data() + root.Header.MarshalBytes(rootData) + off := root.Header.SizeBytes() for _, ep := range root.Entries { - rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(rootData[off:]) + off += ep.Entry.SizeBytes() } - copy(in.diskInode.Data(), rootData) - var fileData []byte for _, ep := range root.Entries { if root.Header.Height == 0 { @@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl // writeTreeToDisk is the recursive step for writeTree which writes the tree // on the disk only. Also writes random file data on disk. func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte { - nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header) + nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:] + curNode.Node.Header.MarshalBytes(nodeData) + off := curNode.Node.Header.SizeBytes() for _, ep := range curNode.Node.Entries { - nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(nodeData[off:]) + off += ep.Entry.SizeBytes() } - copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData) - var fileData []byte for _, ep := range curNode.Node.Entries { if curNode.Node.Header.Height == 0 { diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 557963e03..917f1873d 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -38,11 +38,13 @@ var ( ) // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem // mu serializes changes to the Dentry tree. - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` // dev represents the underlying fs device. It does not require protection // because io.ReaderAt permits concurrent read calls to it. It translates to @@ -81,10 +83,10 @@ var _ vfs.FilesystemImpl = (*filesystem)(nil) // stepLocked is loosely analogous to fs/namei.c:walk_component(). // // Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -// - !rp.Done(). -// - inode == vfsd.Impl().(*Dentry).inode. -func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) { +// * filesystem.mu must be locked (for writing if write param is true). +// * !rp.Done(). +// * inode == vfsd.Impl().(*Dentry).inode. +func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) { if !inode.isDir() { return nil, nil, syserror.ENOTDIR } @@ -100,7 +102,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo } d := vfsd.Impl().(*dentry) if name == ".." { - isRoot, err := rp.CheckRoot(vfsd) + isRoot, err := rp.CheckRoot(ctx, vfsd) if err != nil { return nil, nil, err } @@ -108,7 +110,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo rp.Advance() return vfsd, inode, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, nil, err } rp.Advance() @@ -143,7 +145,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo child.name = name dir.childCache[name] = child } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, nil, err } if child.inode.isSymlink() && rp.ShouldFollowSymlink() { @@ -166,13 +168,13 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo // walkLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). // // Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { +// * filesystem.mu must be locked (for writing if write param is true). +func walkLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() inode := vfsd.Impl().(*dentry).inode for !rp.Done() { var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode, write) + vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write) if err != nil { return nil, nil, err } @@ -194,14 +196,14 @@ func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) // walkParentLocked is loosely analogous to Linux's fs/namei.c:path_parentat(). // // Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -// - !rp.Done(). -func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { +// * filesystem.mu must be locked (for writing if write param is true). +// * !rp.Done(). +func walkParentLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() inode := vfsd.Impl().(*dentry).inode for !rp.Final() { var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode, write) + vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write) if err != nil { return nil, nil, err } @@ -216,7 +218,7 @@ func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, e // the rp till the parent of the last component which should be an existing // directory. If parent is false then resolves rp entirely. Attemps to resolve // the path as far as it can with a read lock and upgrades the lock if needed. -func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) { +func (fs *filesystem) walk(ctx context.Context, rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) { var ( vfsd *vfs.Dentry inode *inode @@ -227,9 +229,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in // of disk. This reduces congestion (allows concurrent walks). fs.mu.RLock() if parent { - vfsd, inode, err = walkParentLocked(rp, false) + vfsd, inode, err = walkParentLocked(ctx, rp, false) } else { - vfsd, inode, err = walkLocked(rp, false) + vfsd, inode, err = walkLocked(ctx, rp, false) } fs.mu.RUnlock() @@ -238,9 +240,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in // walk is fine as this is a read only filesystem. fs.mu.Lock() if parent { - vfsd, inode, err = walkParentLocked(rp, true) + vfsd, inode, err = walkParentLocked(ctx, rp, true) } else { - vfsd, inode, err = walkLocked(rp, true) + vfsd, inode, err = walkLocked(ctx, rp, true) } fs.mu.Unlock() } @@ -283,7 +285,7 @@ func (fs *filesystem) statTo(stat *linux.Statfs) { // AccessAt implements vfs.Filesystem.Impl.AccessAt. func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -292,7 +294,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - vfsd, inode, err := fs.walk(rp, false) + vfsd, inode, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -312,7 +314,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op // GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { - vfsd, inode, err := fs.walk(rp, true) + vfsd, inode, err := fs.walk(ctx, rp, true) if err != nil { return nil, err } @@ -322,7 +324,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - vfsd, inode, err := fs.walk(rp, false) + vfsd, inode, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -336,7 +338,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return "", err } @@ -349,7 +351,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // StatAt implements vfs.FilesystemImpl.StatAt. func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return linux.Statx{}, err } @@ -360,7 +362,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // StatFSAt implements vfs.FilesystemImpl.StatFSAt. func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - if _, _, err := fs.walk(rp, false); err != nil { + if _, _, err := fs.walk(ctx, rp, false); err != nil { return linux.Statfs{}, err } @@ -370,7 +372,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } @@ -390,7 +392,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EEXIST } - if _, _, err := fs.walk(rp, true); err != nil { + if _, _, err := fs.walk(ctx, rp, true); err != nil { return err } @@ -403,7 +405,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return syserror.EEXIST } - if _, _, err := fs.walk(rp, true); err != nil { + if _, _, err := fs.walk(ctx, rp, true); err != nil { return err } @@ -416,7 +418,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return syserror.EEXIST } - _, _, err := fs.walk(rp, true) + _, _, err := fs.walk(ctx, rp, true) if err != nil { return err } @@ -430,7 +432,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return syserror.ENOENT } - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -440,7 +442,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // RmdirAt implements vfs.FilesystemImpl.RmdirAt. func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -454,7 +456,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -468,7 +470,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ return syserror.EEXIST } - _, _, err := fs.walk(rp, true) + _, _, err := fs.walk(ctx, rp, true) if err != nil { return err } @@ -478,7 +480,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ // UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -490,9 +492,9 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -504,36 +506,36 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { - _, _, err := fs.walk(rp, false) +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { + _, _, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } return nil, syserror.ENOTSUP } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { - _, _, err := fs.walk(rp, false) +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { + _, _, err := fs.walk(ctx, rp, false) if err != nil { return "", err } return "", syserror.ENOTSUP } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { - _, _, err := fs.walk(rp, false) +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } return syserror.ENOTSUP } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { - _, _, err := fs.walk(rp, false) +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 30636cf66..9009ba3c7 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -37,6 +37,8 @@ import ( // |-- regular-- // |-- extent file // |-- block map file +// +// +stateify savable type inode struct { // refs is a reference count. refs is accessed using atomic memory operations. refs int64 diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 66d14bb95..4a5539b37 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -31,6 +31,8 @@ import ( // regularFile represents a regular file's inode. This too follows the // inheritance pattern prevelant in the vfs layer described in // pkg/sentry/vfs/README.md. +// +// +stateify savable type regularFile struct { inode inode @@ -67,6 +69,8 @@ func (in *inode) isRegular() bool { // directoryFD represents a directory file description. It implements // vfs.FileDescriptionImpl. +// +// +stateify savable type regularFileFD struct { fileDescription vfs.LockFD @@ -75,11 +79,11 @@ type regularFileFD struct { off int64 // offMu serializes operations that may mutate off. - offMu sync.Mutex + offMu sync.Mutex `state:"nosave"` } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() {} +func (fd *regularFileFD) Release(context.Context) {} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index 62efd4095..5e2bcc837 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -23,6 +23,8 @@ import ( ) // symlink represents a symlink inode. +// +// +stateify savable type symlink struct { inode inode target string // immutable @@ -61,9 +63,11 @@ func (in *inode) isSymlink() bool { return ok } -// symlinkFD represents a symlink file description and implements implements +// symlinkFD represents a symlink file description and implements // vfs.FileDescriptionImpl. which may only be used if open options contains // O_PATH. For this reason most of the functions return EBADF. +// +// +stateify savable type symlinkFD struct { fileDescription vfs.NoLockFD @@ -73,7 +77,7 @@ type symlinkFD struct { var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil) // Release implements vfs.FileDescriptionImpl.Release. -func (fd *symlinkFD) Release() {} +func (fd *symlinkFD) Release(context.Context) {} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go index d8b728f8c..58ef7b9b8 100644 --- a/pkg/sentry/fsimpl/ext/utils.go +++ b/pkg/sentry/fsimpl/ext/utils.go @@ -17,21 +17,21 @@ package ext import ( "io" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) // readFromDisk performs a binary read from disk into the given struct from // the absolute offset provided. -func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error { - n := binary.Size(v) +func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error { + n := v.SizeBytes() buf := make([]byte, n) if read, _ := dev.ReadAt(buf, abOff); read < int(n) { return syserror.EIO } - binary.Unmarshal(buf, binary.LittleEndian, v) + v.UnmarshalBytes(buf) return nil } diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 41567967d..045d7ab08 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -1,19 +1,86 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "request_list", + out = "request_list.go", + package = "fuse", + prefix = "request", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Request", + "Linker": "*Request", + }, +) + +go_template_instance( + name = "inode_refs", + out = "inode_refs.go", + package = "fuse", + prefix = "inode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "inode", + }, +) + go_library( name = "fuse", srcs = [ + "connection.go", + "connection_control.go", "dev.go", + "directory.go", + "file.go", + "fusefs.go", + "inode_refs.go", + "read_write.go", + "register.go", + "regular_file.go", + "request_list.go", + "request_response.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", + "//pkg/marshal", + "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "fuse_test", + size = "small", + srcs = [ + "connection_test.go", + "dev_test.go", + "utils_test.go", + ], + library = ":fuse", + deps = [ + "//pkg/abi/linux", + "//pkg/marshal", + "//pkg/sentry/fsimpl/testutil", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", + "//pkg/waiter", ], ) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go new file mode 100644 index 000000000..8ccda1264 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -0,0 +1,322 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // fuseDefaultMaxBackground is the default value for MaxBackground. + fuseDefaultMaxBackground = 12 + + // fuseDefaultCongestionThreshold is the default value for CongestionThreshold, + // and is 75% of the default maximum of MaxGround. + fuseDefaultCongestionThreshold = (fuseDefaultMaxBackground * 3 / 4) + + // fuseDefaultMaxPagesPerReq is the default value for MaxPagesPerReq. + fuseDefaultMaxPagesPerReq = 32 +) + +// connection is the struct by which the sentry communicates with the FUSE server daemon. +// +// Lock order: +// - conn.fd.mu +// - conn.mu +// - conn.asyncMu +// +// +stateify savable +type connection struct { + fd *DeviceFD + + // mu protects access to struct memebers. + mu sync.Mutex `state:"nosave"` + + // attributeVersion is the version of connection's attributes. + attributeVersion uint64 + + // We target FUSE 7.23. + // The following FUSE_INIT flags are currently unsupported by this implementation: + // - FUSE_EXPORT_SUPPORT + // - FUSE_POSIX_LOCKS: requires POSIX locks + // - FUSE_FLOCK_LOCKS: requires POSIX locks + // - FUSE_AUTO_INVAL_DATA: requires page caching eviction + // - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation + // - FUSE_ASYNC_DIO + // - FUSE_PARALLEL_DIROPS (7.25) + // - FUSE_HANDLE_KILLPRIV (7.26) + // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler (7.26) + // - FUSE_ABORT_ERROR (7.27) + // - FUSE_CACHE_SYMLINKS (7.28) + // - FUSE_NO_OPENDIR_SUPPORT (7.29) + // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction (7.30) + // - FUSE_MAP_ALIGNMENT (7.31) + + // initialized after receiving FUSE_INIT reply. + // Until it's set, suspend sending FUSE requests. + // Use SetInitialized() and IsInitialized() for atomic access. + initialized int32 + + // initializedChan is used to block requests before initialization. + initializedChan chan struct{} `state:".(bool)"` + + // connected (connection established) when a new FUSE file system is created. + // Set to false when: + // umount, + // connection abort, + // device release. + connected bool + + // connInitError if FUSE_INIT encountered error (major version mismatch). + // Only set in INIT. + connInitError bool + + // connInitSuccess if FUSE_INIT is successful. + // Only set in INIT. + // Used for destory (not yet implemented). + connInitSuccess bool + + // aborted via sysfs, and will send ECONNABORTED to read after disconnection (instead of ENODEV). + // Set only if abortErr is true and via fuse control fs (not yet implemented). + // TODO(gvisor.dev/issue/3525): set this to true when user aborts. + aborted bool + + // numWating is the number of requests waiting to be + // sent to FUSE device or being processed by FUSE daemon. + numWaiting uint32 + + // Terminology note: + // + // - `asyncNumMax` is the `MaxBackground` in the FUSE_INIT_IN struct. + // + // - `asyncCongestionThreshold` is the `CongestionThreshold` in the FUSE_INIT_IN struct. + // + // We call the "background" requests in unix term as async requests. + // The "async requests" in unix term is our async requests that expect a reply, + // i.e. `!request.noReply` + + // asyncMu protects the async request fields. + asyncMu sync.Mutex `state:"nosave"` + + // asyncNum is the number of async requests. + // Protected by asyncMu. + asyncNum uint16 + + // asyncCongestionThreshold the number of async requests. + // Negotiated in FUSE_INIT as "CongestionThreshold". + // TODO(gvisor.dev/issue/3529): add congestion control. + // Protected by asyncMu. + asyncCongestionThreshold uint16 + + // asyncNumMax is the maximum number of asyncNum. + // Connection blocks the async requests when it is reached. + // Negotiated in FUSE_INIT as "MaxBackground". + // Protected by asyncMu. + asyncNumMax uint16 + + // maxRead is the maximum size of a read buffer in in bytes. + // Initialized from a fuse fs parameter. + maxRead uint32 + + // maxWrite is the maximum size of a write buffer in bytes. + // Negotiated in FUSE_INIT. + maxWrite uint32 + + // maxPages is the maximum number of pages for a single request to use. + // Negotiated in FUSE_INIT. + maxPages uint16 + + // minor version of the FUSE protocol. + // Negotiated and only set in INIT. + minor uint32 + + // atomicOTrunc is true when FUSE does not send a separate SETATTR request + // before open with O_TRUNC flag. + // Negotiated and only set in INIT. + atomicOTrunc bool + + // asyncRead if read pages asynchronously. + // Negotiated and only set in INIT. + asyncRead bool + + // writebackCache is true for write-back cache policy, + // false for write-through policy. + // Negotiated and only set in INIT. + writebackCache bool + + // bigWrites if doing multi-page cached writes. + // Negotiated and only set in INIT. + bigWrites bool + + // dontMask if filestestem does not apply umask to creation modes. + // Negotiated in INIT. + dontMask bool + + // noOpen if FUSE server doesn't support open operation. + // This flag only influence performance, not correctness of the program. + noOpen bool +} + +func (conn *connection) saveInitializedChan() bool { + select { + case <-conn.initializedChan: + return true // Closed. + default: + return false // Not closed. + } +} + +func (conn *connection) loadInitializedChan(closed bool) { + conn.initializedChan = make(chan struct{}, 1) + if closed { + close(conn.initializedChan) + } +} + +// newFUSEConnection creates a FUSE connection to fd. +func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, opts *filesystemOptions) (*connection, error) { + // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to + // mount a FUSE filesystem. + fuseFD := fd.Impl().(*DeviceFD) + + // Create the writeBuf for the header to be stored in. + hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + fuseFD.writeBuf = make([]byte, hdrLen) + fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse) + fuseFD.fullQueueCh = make(chan struct{}, opts.maxActiveRequests) + fuseFD.writeCursor = 0 + + return &connection{ + fd: fuseFD, + asyncNumMax: fuseDefaultMaxBackground, + asyncCongestionThreshold: fuseDefaultCongestionThreshold, + maxRead: opts.maxRead, + maxPages: fuseDefaultMaxPagesPerReq, + initializedChan: make(chan struct{}), + connected: true, + }, nil +} + +// CallAsync makes an async (aka background) request. +// It's a simple wrapper around Call(). +func (conn *connection) CallAsync(t *kernel.Task, r *Request) error { + r.async = true + _, err := conn.Call(t, r) + return err +} + +// Call makes a request to the server. +// Block before the connection is initialized. +// When the Request is FUSE_INIT, it will not be blocked before initialization. +// Task should never be nil. +// +// For a sync request, it blocks the invoking task until +// a server responds with a response. +// +// For an async request (that do not expect a response immediately), +// it returns directly unless being blocked either before initialization +// or when there are too many async requests ongoing. +// +// Example for async request: +// init, readahead, write, async read/write, fuse_notify_reply, +// non-sync release, interrupt, forget. +// +// The forget request does not have a reply, +// as documented in include/uapi/linux/fuse.h:FUSE_FORGET. +func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) { + // Block requests sent before connection is initalized. + if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT { + if err := t.Block(conn.initializedChan); err != nil { + return nil, err + } + } + + if !conn.connected { + return nil, syserror.ENOTCONN + } + + if conn.connInitError { + return nil, syserror.ECONNREFUSED + } + + fut, err := conn.callFuture(t, r) + if err != nil { + return nil, err + } + + return fut.resolve(t) +} + +// callFuture makes a request to the server and returns a future response. +// Call resolve() when the response needs to be fulfilled. +func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + + // Is the queue full? + // + // We must busy wait here until the request can be queued. We don't + // block on the fd.fullQueueCh with a lock - so after being signalled, + // before we acquire the lock, it is possible that a barging task enters + // and queues a request. As a result, upon acquiring the lock we must + // again check if the room is available. + // + // This can potentially starve a request forever but this can only happen + // if there are always too many ongoing requests all the time. The + // supported maxActiveRequests setting should be really high to avoid this. + for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { + log.Infof("Blocking request %v from being queued. Too many active requests: %v", + r.id, conn.fd.numActiveRequests) + conn.fd.mu.Unlock() + err := t.Block(conn.fd.fullQueueCh) + conn.fd.mu.Lock() + if err != nil { + return nil, err + } + } + + return conn.callFutureLocked(t, r) +} + +// callFutureLocked makes a request to the server and returns a future response. +func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) { + // Check connected again holding conn.mu. + conn.mu.Lock() + if !conn.connected { + conn.mu.Unlock() + // we checked connected before, + // this must be due to aborted connection. + return nil, syserror.ECONNABORTED + } + conn.mu.Unlock() + + conn.fd.queue.PushBack(r) + conn.fd.numActiveRequests++ + fut := newFutureResponse(r) + conn.fd.completions[r.id] = fut + + // Signal the readers that there is something to read. + conn.fd.waitQueue.Notify(waiter.EventIn) + + return fut, nil +} diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go new file mode 100644 index 000000000..bfde78559 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -0,0 +1,247 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// consts used by FUSE_INIT negotiation. +const ( + // fuseMaxMaxPages is the maximum value for MaxPages received in InitOut. + // Follow the same behavior as unix fuse implementation. + fuseMaxMaxPages = 256 + + // Maximum value for the time granularity for file time stamps, 1s. + // Follow the same behavior as unix fuse implementation. + fuseMaxTimeGranNs = 1000000000 + + // Minimum value for MaxWrite and MaxRead. + // Follow the same behavior as unix fuse implementation. + fuseMinMaxWrite = 4096 + fuseMinMaxRead = 4096 + + // Temporary default value for max readahead, 128kb. + fuseDefaultMaxReadahead = 131072 + + // The FUSE_INIT_IN flags sent to the daemon. + // TODO(gvisor.dev/issue/3199): complete the flags. + fuseDefaultInitFlags = linux.FUSE_MAX_PAGES +) + +// Adjustable maximums for Connection's cogestion control parameters. +// Used as the upperbound of the config values. +// Currently we do not support adjustment to them. +var ( + MaxUserBackgroundRequest uint16 = fuseDefaultMaxBackground + MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold +) + +// SetInitialized atomically sets the connection as initialized. +func (conn *connection) SetInitialized() { + // Unblock the requests sent before INIT. + close(conn.initializedChan) + + // Close the channel first to avoid the non-atomic situation + // where conn.initialized is true but there are + // tasks being blocked on the channel. + // And it prevents the newer tasks from gaining + // unnecessary higher chance to be issued before the blocked one. + + atomic.StoreInt32(&(conn.initialized), int32(1)) +} + +// IsInitialized atomically check if the connection is initialized. +// pairs with SetInitialized(). +func (conn *connection) Initialized() bool { + return atomic.LoadInt32(&(conn.initialized)) != 0 +} + +// InitSend sends a FUSE_INIT request. +func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { + in := linux.FUSEInitIn{ + Major: linux.FUSE_KERNEL_VERSION, + Minor: linux.FUSE_KERNEL_MINOR_VERSION, + // TODO(gvisor.dev/issue/3196): find appropriate way to calculate this + MaxReadahead: fuseDefaultMaxReadahead, + Flags: fuseDefaultInitFlags, + } + + req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) + if err != nil { + return err + } + + // Since there is no task to block on and FUSE_INIT is the request + // to unblock other requests, use nil. + return conn.CallAsync(nil, req) +} + +// InitRecv receives a FUSE_INIT reply and process it. +// +// Preconditions: conn.asyncMu must not be held if minor verion is newer than 13. +func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error { + if err := res.Error(); err != nil { + return err + } + + initRes := fuseInitRes{initLen: res.DataLen()} + if err := res.UnmarshalPayload(&initRes); err != nil { + return err + } + + return conn.initProcessReply(&initRes.initOut, hasSysAdminCap) +} + +// Process the FUSE_INIT reply from the FUSE server. +// It tries to acquire the conn.asyncMu lock if minor version is newer than 13. +func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error { + // No matter error or not, always set initialzied. + // to unblock the blocked requests. + defer conn.SetInitialized() + + // No support for old major fuse versions. + if out.Major != linux.FUSE_KERNEL_VERSION { + conn.connInitError = true + return nil + } + + // Start processing the reply. + conn.connInitSuccess = true + conn.minor = out.Minor + + // No support for negotiating MaxWrite before minor version 5. + if out.Minor >= 5 { + conn.maxWrite = out.MaxWrite + } else { + conn.maxWrite = fuseMinMaxWrite + } + if conn.maxWrite < fuseMinMaxWrite { + conn.maxWrite = fuseMinMaxWrite + } + + // No support for the following flags before minor version 6. + if out.Minor >= 6 { + conn.asyncRead = out.Flags&linux.FUSE_ASYNC_READ != 0 + conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0 + conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0 + conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0 + + // TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs). + + if out.Flags&linux.FUSE_MAX_PAGES != 0 { + maxPages := out.MaxPages + if maxPages < 1 { + maxPages = 1 + } + if maxPages > fuseMaxMaxPages { + maxPages = fuseMaxMaxPages + } + conn.maxPages = maxPages + } + } + + // No support for limits before minor version 13. + if out.Minor >= 13 { + conn.asyncMu.Lock() + + if out.MaxBackground > 0 { + conn.asyncNumMax = out.MaxBackground + + if !hasSysAdminCap && + conn.asyncNumMax > MaxUserBackgroundRequest { + conn.asyncNumMax = MaxUserBackgroundRequest + } + } + + if out.CongestionThreshold > 0 { + conn.asyncCongestionThreshold = out.CongestionThreshold + + if !hasSysAdminCap && + conn.asyncCongestionThreshold > MaxUserCongestionThreshold { + conn.asyncCongestionThreshold = MaxUserCongestionThreshold + } + } + + conn.asyncMu.Unlock() + } + + return nil +} + +// Abort this FUSE connection. +// It tries to acquire conn.fd.mu, conn.lock, conn.bgLock in order. +// All possible requests waiting or blocking will be aborted. +// +// Preconditions: conn.fd.mu is locked. +func (conn *connection) Abort(ctx context.Context) { + conn.mu.Lock() + conn.asyncMu.Lock() + + if !conn.connected { + conn.asyncMu.Unlock() + conn.mu.Unlock() + conn.fd.mu.Unlock() + return + } + + conn.connected = false + + // Empty the `fd.queue` that holds the requests + // not yet read by the FUSE daemon yet. + // These are a subset of the requests in `fuse.completion` map. + for !conn.fd.queue.Empty() { + req := conn.fd.queue.Front() + conn.fd.queue.Remove(req) + } + + var terminate []linux.FUSEOpID + + // 2. Collect the requests have not been sent to FUSE daemon, + // or have not received a reply. + for unique := range conn.fd.completions { + terminate = append(terminate, unique) + } + + // Release locks to avoid deadlock. + conn.asyncMu.Unlock() + conn.mu.Unlock() + + // 1. The requets blocked before initialization. + // Will reach call() `connected` check and return. + if !conn.Initialized() { + conn.SetInitialized() + } + + // 2. Terminate the requests collected above. + // Set ECONNABORTED error. + // sendError() will remove them from `fd.completion` map. + // Will enter the path of a normally received error. + for _, toTerminate := range terminate { + conn.fd.sendError(ctx, -int32(syscall.ECONNABORTED), toTerminate) + } + + // 3. The requests not yet written to FUSE device. + // Early terminate. + // Will reach callFutureLocked() `connected` check and return. + close(conn.fd.fullQueueCh) + + // TODO(gvisor.dev/issue/3528): Forget all pending forget reqs. +} diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go new file mode 100644 index 000000000..91d16c1cf --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "math/rand" + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" +) + +// TestConnectionInitBlock tests if initialization +// correctly blocks and unblocks the connection. +// Since it's unfeasible to test kernelTask.Block() in unit test, +// the code in Call() are not tested here. +func TestConnectionInitBlock(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + + conn, _, err := newTestConnection(s, k, maxActiveRequestsDefault) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + select { + case <-conn.initializedChan: + t.Fatalf("initializedChan should be blocking before SetInitialized") + default: + } + + conn.SetInitialized() + + select { + case <-conn.initializedChan: + default: + t.Fatalf("initializedChan should not be blocking after SetInitialized") + } +} + +func TestConnectionAbort(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + creds := auth.CredentialsFromContext(s.Ctx) + task := kernel.TaskFromContext(s.Ctx) + + const numRequests uint64 = 256 + + conn, _, err := newTestConnection(s, k, numRequests) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + testObj := &testPayload{ + data: rand.Uint32(), + } + + var futNormal []*futureResponse + + for i := 0; i < int(numRequests); i++ { + req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + fut, err := conn.callFutureLocked(task, req) + if err != nil { + t.Fatalf("callFutureLocked failed: %v", err) + } + futNormal = append(futNormal, fut) + } + + conn.Abort(s.Ctx) + + // Abort should unblock the initialization channel. + // Note: no test requests are actually blocked on `conn.initializedChan`. + select { + case <-conn.initializedChan: + default: + t.Fatalf("initializedChan should not be blocking after SetInitialized") + } + + // Abort will return ECONNABORTED error to unblocked requests. + for _, fut := range futNormal { + if fut.getResponse().hdr.Error != -int32(syscall.ECONNABORTED) { + t.Fatalf("Incorrect error code received for aborted connection: %v", fut.getResponse().hdr.Error) + } + } + + // After abort, Call() should return directly with ENOTCONN. + req, err := conn.NewRequest(creds, 0, 0, 0, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + _, err = conn.Call(task, req) + if err != syserror.ENOTCONN { + t.Fatalf("Incorrect error code received for Call() after connection aborted") + } + +} diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index f6a67d005..1b86a4b4c 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -15,21 +15,32 @@ package fuse import ( + "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" ) const fuseDevMinor = 229 // fuseDevice implements vfs.Device for /dev/fuse. +// +// +stateify savable type fuseDevice struct{} // Open implements vfs.Device.Open. func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + if !kernel.FUSEEnabled { + return nil, syserror.ENOENT + } + var fd DeviceFD if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ UseDentryMetadata: true, @@ -40,60 +51,412 @@ func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse. +// +// +stateify savable type DeviceFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl vfs.NoLockFD - // TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue - // and deque requests, control synchronization and establish communication - // between the FUSE kernel module and the /dev/fuse character device. + // nextOpID is used to create new requests. + nextOpID linux.FUSEOpID + + // queue is the list of requests that need to be processed by the FUSE server. + queue requestList + + // numActiveRequests is the number of requests made by the Sentry that has + // yet to be responded to. + numActiveRequests uint64 + + // completions is used to map a request to its response. A Writer will use this + // to notify the caller of a completed response. + completions map[linux.FUSEOpID]*futureResponse + + writeCursor uint32 + + // writeBuf is the memory buffer used to copy in the FUSE out header from + // userspace. + writeBuf []byte + + // writeCursorFR current FR being copied from server. + writeCursorFR *futureResponse + + // mu protects all the queues, maps, buffers and cursors and nextOpID. + mu sync.Mutex `state:"nosave"` + + // waitQueue is used to notify interested parties when the device becomes + // readable or writable. + waitQueue waiter.Queue + + // fullQueueCh is a channel used to synchronize the readers with the writers. + // Writers (inbound requests to the filesystem) block if there are too many + // unprocessed in-flight requests. + fullQueueCh chan struct{} `state:".(int)"` + + // fs is the FUSE filesystem that this FD is being used for. + fs *filesystem +} + +func (fd *DeviceFD) saveFullQueueCh() int { + return cap(fd.fullQueueCh) +} + +func (fd *DeviceFD) loadFullQueueCh(capacity int) { + fd.fullQueueCh = make(chan struct{}, capacity) } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DeviceFD) Release() {} +func (fd *DeviceFD) Release(ctx context.Context) { + if fd.fs != nil { + fd.fs.conn.mu.Lock() + fd.fs.conn.connected = false + fd.fs.conn.mu.Unlock() + + fd.fs.VFSFilesystem().DecRef(ctx) + fd.fs = nil + } +} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if fd.fs == nil { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } // Read implements vfs.FileDescriptionImpl.Read. func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return 0, syserror.ENOSYS + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if fd.fs == nil { + return 0, syserror.EPERM + } + + // Return ENODEV if the filesystem is umounted. + if fd.fs.umounted { + // TODO(gvisor.dev/issue/3525): return ECONNABORTED if aborted via fuse control fs. + return 0, syserror.ENODEV + } + + // We require that any Read done on this filesystem have a sane minimum + // read buffer. It must have the capacity for the fixed parts of any request + // header (Linux uses the request header and the FUSEWriteIn header for this + // calculation) + the negotiated MaxWrite room for the data. + minBuffSize := linux.FUSE_MIN_READ_BUFFER + inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) + writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes()) + negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.maxWrite + if minBuffSize < negotiatedMinBuffSize { + minBuffSize = negotiatedMinBuffSize + } + + // If the read buffer is too small, error out. + if dst.NumBytes() < int64(minBuffSize) { + return 0, syserror.EINVAL + } + + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.readLocked(ctx, dst, opts) +} + +// readLocked implements the reading of the fuse device while locked with DeviceFD.mu. +// +// Preconditions: dst is large enough for any reasonable request. +func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + var req *Request + + // Find the first valid request. + // For the normal case this loop only execute once. + for !fd.queue.Empty() { + req = fd.queue.Front() + + if int64(req.hdr.Len)+int64(len(req.payload)) <= dst.NumBytes() { + break + } + + // The request is too large. Cannot process it. All requests must be smaller than the + // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT + // handshake. + errno := -int32(syscall.EIO) + if req.hdr.Opcode == linux.FUSE_SETXATTR { + errno = -int32(syscall.E2BIG) + } + + // Return the error to the calling task. + if err := fd.sendError(ctx, errno, req.hdr.Unique); err != nil { + return 0, err + } + + // We're done with this request. + fd.queue.Remove(req) + req = nil + } + + if req == nil { + return 0, syserror.ErrWouldBlock + } + + // We already checked the size: dst must be able to fit the whole request. + // Now we write the marshalled header, the payload, + // and the potential additional payload + // to the user memory IOSequence. + + n, err := dst.CopyOut(ctx, req.data) + if err != nil { + return 0, err + } + if n != len(req.data) { + return 0, syserror.EIO + } + + if req.hdr.Opcode == linux.FUSE_WRITE { + written, err := dst.DropFirst(n).CopyOut(ctx, req.payload) + if err != nil { + return 0, err + } + if written != len(req.payload) { + return 0, syserror.EIO + } + n += int(written) + } + + // Fully done with this req, remove it from the queue. + fd.queue.Remove(req) + + // Remove noReply ones from map of requests expecting a reply. + if req.noReply { + fd.numActiveRequests -= 1 + delete(fd.completions, req.hdr.Unique) + } + + return int64(n), nil } // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if fd.fs == nil { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } // Write implements vfs.FileDescriptionImpl.Write. func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.ENOSYS + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.writeLocked(ctx, src, opts) +} + +// writeLocked implements writing to the fuse device while locked with DeviceFD.mu. +func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if fd.fs == nil { + return 0, syserror.EPERM + } + + // Return ENODEV if the filesystem is umounted. + if fd.fs.umounted { + return 0, syserror.ENODEV + } + + var cn, n int64 + hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + + for src.NumBytes() > 0 { + if fd.writeCursorFR != nil { + // Already have common header, and we're now copying the payload. + wantBytes := fd.writeCursorFR.hdr.Len + + // Note that the FR data doesn't have the header. Copy it over if its necessary. + if fd.writeCursorFR.data == nil { + fd.writeCursorFR.data = make([]byte, wantBytes) + } + + bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes]) + if err != nil { + return 0, err + } + src = src.DropFirst(bytesCopied) + + cn = int64(bytesCopied) + n += cn + fd.writeCursor += uint32(cn) + if fd.writeCursor == wantBytes { + // Done reading this full response. Clean up and unblock the + // initiator. + break + } + + // Check if we have more data in src. + continue + } + + // Assert that the header isn't read into the writeBuf yet. + if fd.writeCursor >= hdrLen { + return 0, syserror.EINVAL + } + + // We don't have the full common response header yet. + wantBytes := hdrLen - fd.writeCursor + bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes]) + if err != nil { + return 0, err + } + src = src.DropFirst(bytesCopied) + + cn = int64(bytesCopied) + n += cn + fd.writeCursor += uint32(cn) + if fd.writeCursor == hdrLen { + // Have full header in the writeBuf. Use it to fetch the actual futureResponse + // from the device's completions map. + var hdr linux.FUSEHeaderOut + hdr.UnmarshalBytes(fd.writeBuf) + + // We have the header now and so the writeBuf has served its purpose. + // We could reset it manually here but instead of doing that, at the + // end of the write, the writeCursor will be set to 0 thereby allowing + // the next request to overwrite whats in the buffer, + + fut, ok := fd.completions[hdr.Unique] + if !ok { + // Server sent us a response for a request we never sent, + // or for which we already received a reply (e.g. aborted), an unlikely event. + return 0, syserror.EINVAL + } + + delete(fd.completions, hdr.Unique) + + // Copy over the header into the future response. The rest of the payload + // will be copied over to the FR's data in the next iteration. + fut.hdr = &hdr + fd.writeCursorFR = fut + + // Next iteration will now try read the complete request, if src has + // any data remaining. Otherwise we're done. + } + } + + if fd.writeCursorFR != nil { + if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil { + return 0, err + } + + // Ready the device for the next request. + fd.writeCursorFR = nil + fd.writeCursor = 0 + } + + return n, nil +} + +// Readiness implements vfs.FileDescriptionImpl.Readiness. +func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.readinessLocked(mask) +} + +// readinessLocked implements checking the readiness of the fuse device while +// locked with DeviceFD.mu. +func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask { + var ready waiter.EventMask + + if fd.fs.umounted { + ready |= waiter.EventErr + return ready & mask + } + + // FD is always writable. + ready |= waiter.EventOut + if !fd.queue.Empty() { + // Have reqs available, FD is readable. + ready |= waiter.EventIn + } + + return ready & mask +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.waitQueue.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *DeviceFD) EventUnregister(e *waiter.Entry) { + fd.waitQueue.EventUnregister(e) } // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if fd.fs == nil { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } -// Register registers the FUSE device with vfsObj. -func Register(vfsObj *vfs.VirtualFilesystem) error { - if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{ - GroupName: "misc", - }); err != nil { - return err +// sendResponse sends a response to the waiting task (if any). +// +// Preconditions: fd.mu must be held. +func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error { + // Signal the task waiting on a response if any. + defer close(fut.ch) + + // Signal that the queue is no longer full. + select { + case fd.fullQueueCh <- struct{}{}: + default: + } + fd.numActiveRequests-- + + if fut.async { + return fd.asyncCallBack(ctx, fut.getResponse()) } return nil } -// CreateDevtmpfsFile creates a device special file in devtmpfs. -func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error { - if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil { - return err +// sendError sends an error response to the waiting task (if any) by calling sendResponse(). +// +// Preconditions: fd.mu must be held. +func (fd *DeviceFD) sendError(ctx context.Context, errno int32, unique linux.FUSEOpID) error { + // Return the error to the calling task. + outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + respHdr := linux.FUSEHeaderOut{ + Len: outHdrLen, + Error: errno, + Unique: unique, + } + + fut, ok := fd.completions[respHdr.Unique] + if !ok { + // A response for a request we never sent, + // or for which we already received a reply (e.g. aborted). + return syserror.EINVAL + } + delete(fd.completions, respHdr.Unique) + + fut.hdr = &respHdr + return fd.sendResponse(ctx, fut) +} + +// asyncCallBack executes pre-defined callback function for async requests. +// Currently used by: FUSE_INIT. +func (fd *DeviceFD) asyncCallBack(ctx context.Context, r *Response) error { + switch r.opcode { + case linux.FUSE_INIT: + creds := auth.CredentialsFromContext(ctx) + rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace() + return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs)) + // TODO(gvisor.dev/issue/3247): support async read: correctly process the response. } return nil diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go new file mode 100644 index 000000000..5986133e9 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -0,0 +1,323 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "fmt" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// echoTestOpcode is the Opcode used during testing. The server used in tests +// will simply echo the payload back with the appropriate headers. +const echoTestOpcode linux.FUSEOpcode = 1000 + +// TestFUSECommunication tests that the communication layer between the Sentry and the +// FUSE server daemon works as expected. +func TestFUSECommunication(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + creds := auth.CredentialsFromContext(s.Ctx) + + // Create test cases with different number of concurrent clients and servers. + testCases := []struct { + Name string + NumClients int + NumServers int + MaxActiveRequests uint64 + }{ + { + Name: "SingleClientSingleServer", + NumClients: 1, + NumServers: 1, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "SingleClientMultipleServers", + NumClients: 1, + NumServers: 10, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "MultipleClientsSingleServer", + NumClients: 10, + NumServers: 1, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "MultipleClientsMultipleServers", + NumClients: 10, + NumServers: 10, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "RequestCapacityFull", + NumClients: 10, + NumServers: 1, + MaxActiveRequests: 1, + }, + { + Name: "RequestCapacityContinuouslyFull", + NumClients: 100, + NumServers: 2, + MaxActiveRequests: 2, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + clientsDone := make([]chan struct{}, testCase.NumClients) + serversDone := make([]chan struct{}, testCase.NumServers) + serversKill := make([]chan struct{}, testCase.NumServers) + + // FUSE clients. + for i := 0; i < testCase.NumClients; i++ { + clientsDone[i] = make(chan struct{}) + go func(i int) { + fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i]) + }(i) + } + + // FUSE servers. + for j := 0; j < testCase.NumServers; j++ { + serversDone[j] = make(chan struct{}) + serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block. + go func(j int) { + fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j]) + }(j) + } + + // Tear down. + // + // Make sure all the clients are done. + for i := 0; i < testCase.NumClients; i++ { + <-clientsDone[i] + } + + // Kill any server that is potentially waiting. + for j := 0; j < testCase.NumServers; j++ { + serversKill[j] <- struct{}{} + } + + // Make sure all the servers are done. + for j := 0; j < testCase.NumServers; j++ { + <-serversDone[j] + } + }) + } +} + +// CallTest makes a request to the server and blocks the invoking +// goroutine until a server responds with a response. Doesn't block +// a kernel.Task. Analogous to Connection.Call but used for testing. +func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) { + conn.fd.mu.Lock() + + // Wait until we're certain that a new request can be processed. + for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { + conn.fd.mu.Unlock() + select { + case <-conn.fd.fullQueueCh: + } + conn.fd.mu.Lock() + } + + fut, err := conn.callFutureLocked(t, r) // No task given. + conn.fd.mu.Unlock() + + if err != nil { + return nil, err + } + + // Resolve the response. + // + // Block without a task. + select { + case <-fut.ch: + } + + // A response is ready. Resolve and return it. + return fut.getResponse(), nil +} + +// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE +// device. However, it does so by - not blocking the task that is calling - and +// instead just waits on a channel. The behaviour is essentially the same as +// DeviceFD.Read except it guarantees that the task is not blocked. +func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) { + var err error + var n, total int64 + + dev := fd.Impl().(*DeviceFD) + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + dev.EventRegister(&w, waiter.EventIn) + for { + // Issue the request and break out if it completes with anything other than + // "would block". + n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{}) + total += n + if err != syserror.ErrWouldBlock { + break + } + + // Wait for a notification that we should retry. + // Emulate the blocking for when no requests are available + select { + case <-ch: + case <-killServer: + // Server killed by the main program. + return 0, true, nil + } + } + + dev.EventUnregister(&w) + return total, false, err +} + +// fuseClientRun emulates all the actions of a normal FUSE request. It creates +// a header, a payload, calls the server, waits for the response, and processes +// the response. +func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) { + defer func() { clientDone <- struct{}{} }() + + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root) + if err != nil { + t.Fatal(err) + } + testObj := &testPayload{ + data: rand.Uint32(), + } + + req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + + // Queue up a request. + // Analogous to Call except it doesn't block on the task. + resp, err := CallTest(conn, clientTask, req, pid) + if err != nil { + t.Fatalf("CallTaskNonBlock failed: %v", err) + } + + if err = resp.Error(); err != nil { + t.Fatalf("Server responded with an error: %v", err) + } + + var respTestPayload testPayload + if err := resp.UnmarshalPayload(&respTestPayload); err != nil { + t.Fatalf("Unmarshalling payload error: %v", err) + } + + if resp.hdr.Unique != req.hdr.Unique { + t.Fatalf("got response for another request. Expected response for req %v but got response for req %v", + req.hdr.Unique, resp.hdr.Unique) + } + + if respTestPayload.data != testObj.data { + t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) + } + +} + +// fuseServerRun creates a task and emulates all the actions of a simple FUSE server +// that simply reads a request and echos the same struct back as a response using the +// appropriate headers. +func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) { + defer func() { serverDone <- struct{}{} }() + + // Create the tasks that the server will be using. + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + var readPayload testPayload + + serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) + if err != nil { + t.Fatal(err) + } + + // Read the request. + for { + inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) + payloadLen := uint32(readPayload.SizeBytes()) + + // The raed buffer must meet some certain size criteria. + buffSize := inHdrLen + payloadLen + if buffSize < linux.FUSE_MIN_READ_BUFFER { + buffSize = linux.FUSE_MIN_READ_BUFFER + } + inBuf := make([]byte, buffSize) + inIOseq := usermem.BytesIOSequence(inBuf) + + n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer) + if err != nil { + t.Fatalf("Read failed :%v", err) + } + + // Server should shut down. No new requests are going to be made. + if serverKilled { + break + } + + if n <= 0 { + t.Fatalf("Read read no bytes") + } + + var readFUSEHeaderIn linux.FUSEHeaderIn + readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) + readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) + + if readFUSEHeaderIn.Opcode != echoTestOpcode { + t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) + } + + // Write the response. + outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + outBuf := make([]byte, outHdrLen+payloadLen) + outHeader := linux.FUSEHeaderOut{ + Len: outHdrLen + payloadLen, + Error: 0, + Unique: readFUSEHeaderIn.Unique, + } + + // Echo the payload back. + outHeader.MarshalUnsafe(outBuf[:outHdrLen]) + readPayload.MarshalUnsafe(outBuf[outHdrLen:]) + outIOseq := usermem.BytesIOSequence(outBuf) + + n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed :%v", err) + } + } +} diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go new file mode 100644 index 000000000..8f220a04b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type directoryFD struct { + fileDescription +} + +// Allocate implements directoryFD.Allocate. +func (*directoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EISDIR +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (*directoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (*directoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (*directoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (*directoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback) error { + fusefs := dir.inode().fs + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + + in := linux.FUSEReadIn{ + Fh: dir.Fh, + Offset: uint64(atomic.LoadInt64(&dir.off)), + Size: linux.FUSE_PAGE_SIZE, + Flags: dir.statusFlags(), + } + + // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS. + req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) + if err != nil { + return err + } + + res, err := fusefs.conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + + var out linux.FUSEDirents + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + for _, fuseDirent := range out.Dirents { + nextOff := int64(fuseDirent.Meta.Off) + dirent := vfs.Dirent{ + Name: fuseDirent.Name, + Type: uint8(fuseDirent.Meta.Type), + Ino: fuseDirent.Meta.Ino, + NextOff: nextOff, + } + + if err := callback.Handle(dirent); err != nil { + return err + } + atomic.StoreInt64(&dir.off, nextOff) + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go new file mode 100644 index 000000000..83f2816b7 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/file.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// fileDescription implements vfs.FileDescriptionImpl for fuse. +type fileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD + + // the file handle used in userspace. + Fh uint64 + + // Nonseekable is indicate cannot perform seek on a file. + Nonseekable bool + + // DirectIO suggest fuse to use direct io operation. + DirectIO bool + + // OpenFlag is the flag returned by open. + OpenFlag uint32 + + // off is the file offset. + off int64 +} + +func (fd *fileDescription) dentry() *kernfs.Dentry { + return fd.vfsfd.Dentry().Impl().(*kernfs.Dentry) +} + +func (fd *fileDescription) inode() *inode { + return fd.dentry().Inode().(*inode) +} + +func (fd *fileDescription) filesystem() *vfs.Filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem() +} + +func (fd *fileDescription) statusFlags() uint32 { + return fd.vfsfd.StatusFlags() +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *fileDescription) Release(ctx context.Context) { + // no need to release if FUSE server doesn't implement Open. + conn := fd.inode().fs.conn + if conn.noOpen { + return + } + + in := linux.FUSEReleaseIn{ + Fh: fd.Fh, + Flags: fd.statusFlags(), + } + // TODO(gvisor.dev/issue/3245): add logic when we support file lock owner. + var opcode linux.FUSEOpcode + if fd.inode().Mode().IsDir() { + opcode = linux.FUSE_RELEASEDIR + } else { + opcode = linux.FUSE_RELEASE + } + kernelTask := kernel.TaskFromContext(ctx) + // ignoring errors and FUSE server reply is analogous to Linux's behavior. + req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) + if err != nil { + // No way to invoke Call() with an errored request. + return + } + // The reply will be ignored since no callback is defined in asyncCallBack(). + conn.CallAsync(kernelTask, req) +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, nil +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return 0, nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.filesystem() + inode := fd.inode() + return inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + fs := fd.filesystem() + creds := auth.CredentialsFromContext(ctx) + return fd.inode().setAttr(ctx, fs, creds, opts, true, fd.Fh) +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go new file mode 100644 index 000000000..65786e42a --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -0,0 +1,826 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fuse implements fusefs. +package fuse + +import ( + "math" + "strconv" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Name is the default filesystem name. +const Name = "fuse" + +// maxActiveRequestsDefault is the default setting controlling the upper bound +// on the number of active requests at any given time. +const maxActiveRequestsDefault = 10000 + +// FilesystemType implements vfs.FilesystemType. +// +// +stateify savable +type FilesystemType struct{} + +// +stateify savable +type filesystemOptions struct { + // userID specifies the numeric uid of the mount owner. + // This option should not be specified by the filesystem owner. + // It is set by libfuse (or, if libfuse is not used, must be set + // by the filesystem itself). For more information, see man page + // for fuse(8) + userID uint32 + + // groupID specifies the numeric gid of the mount owner. + // This option should not be specified by the filesystem owner. + // It is set by libfuse (or, if libfuse is not used, must be set + // by the filesystem itself). For more information, see man page + // for fuse(8) + groupID uint32 + + // rootMode specifies the the file mode of the filesystem's root. + rootMode linux.FileMode + + // maxActiveRequests specifies the maximum number of active requests that can + // exist at any time. Any further requests will block when trying to + // Call the server. + maxActiveRequests uint64 + + // maxRead is the max number of bytes to read, + // specified as "max_read" in fs parameters. + // If not specified by user, use math.MaxUint32 as default value. + maxRead uint32 +} + +// filesystem implements vfs.FilesystemImpl. +// +// +stateify savable +type filesystem struct { + kernfs.Filesystem + devMinor uint32 + + // conn is used for communication between the FUSE server + // daemon and the sentry fusefs. + conn *connection + + // opts is the options the fusefs is initialized with. + opts *filesystemOptions + + // umounted is true if filesystem.Release() has been called. + umounted bool +} + +// Name implements vfs.FilesystemType.Name. +func (FilesystemType) Name() string { + return Name +} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + var fsopts filesystemOptions + mopts := vfs.GenericParseMountOptions(opts.Data) + deviceDescriptorStr, ok := mopts["fd"] + if !ok { + log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name()) + return nil, nil, syserror.EINVAL + } + delete(mopts, "fd") + + deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) + if err != nil { + return nil, nil, err + } + + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name()) + return nil, nil, syserror.EINVAL + } + fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + + // Parse and set all the other supported FUSE mount options. + // TODO(gVisor.dev/issue/3229): Expand the supported mount options. + if userIDStr, ok := mopts["user_id"]; ok { + delete(mopts, "user_id") + userID, err := strconv.ParseUint(userIDStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr) + return nil, nil, syserror.EINVAL + } + fsopts.userID = uint32(userID) + } + + if groupIDStr, ok := mopts["group_id"]; ok { + delete(mopts, "group_id") + groupID, err := strconv.ParseUint(groupIDStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr) + return nil, nil, syserror.EINVAL + } + fsopts.groupID = uint32(groupID) + } + + rootMode := linux.FileMode(0777) + modeStr, ok := mopts["rootmode"] + if ok { + delete(mopts, "rootmode") + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr) + return nil, nil, syserror.EINVAL + } + rootMode = linux.FileMode(mode) + } + fsopts.rootMode = rootMode + + // Set the maxInFlightRequests option. + fsopts.maxActiveRequests = maxActiveRequestsDefault + + if maxReadStr, ok := mopts["max_read"]; ok { + delete(mopts, "max_read") + maxRead, err := strconv.ParseUint(maxReadStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid max_read: max_read=%s", fsType.Name(), maxReadStr) + return nil, nil, syserror.EINVAL + } + if maxRead < fuseMinMaxRead { + maxRead = fuseMinMaxRead + } + fsopts.maxRead = uint32(maxRead) + } else { + fsopts.maxRead = math.MaxUint32 + } + + // Check for unparsed options. + if len(mopts) != 0 { + log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts) + return nil, nil, syserror.EINVAL + } + + // Create a new FUSE filesystem. + fs, err := newFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + if err != nil { + log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) + return nil, nil, err + } + + fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + + // Send a FUSE_INIT request to the FUSE daemon server before returning. + // This call is not blocking. + if err := fs.conn.InitSend(creds, uint32(kernelTask.ThreadID())); err != nil { + log.Warningf("%s.InitSend: failed with error: %v", fsType.Name(), err) + return nil, nil, err + } + + // root is the fusefs root directory. + root := fs.newRootInode(creds, fsopts.rootMode) + + return fs.VFSFilesystem(), root.VFSDentry(), nil +} + +// newFUSEFilesystem creates a new FUSE filesystem. +func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { + conn, err := newFUSEConnection(ctx, device, opts) + if err != nil { + log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) + return nil, syserror.EINVAL + } + + fuseFD := device.Impl().(*DeviceFD) + + fs := &filesystem{ + devMinor: devMinor, + opts: opts, + conn: conn, + } + + fs.VFSFilesystem().IncRef() + fuseFD.fs = fs + + return fs, nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release(ctx context.Context) { + fs.conn.fd.mu.Lock() + + fs.umounted = true + fs.conn.Abort(ctx) + // Notify all the waiters on this fd. + fs.conn.fd.waitQueue.Notify(waiter.EventIn) + + fs.conn.fd.mu.Unlock() + + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release(ctx) +} + +// inode implements kernfs.Inode. +// +// +stateify savable +type inode struct { + inodeRefs + kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNoDynamicLookup + kernfs.InodeNotSymlink + kernfs.OrderedChildren + + dentry kernfs.Dentry + + // the owning filesystem. fs is immutable. + fs *filesystem + + // metaDataMu protects the metadata of this inode. + metadataMu sync.Mutex + + nodeID uint64 + + locks vfs.FileLocks + + // size of the file. + size uint64 + + // attributeVersion is the version of inode's attributes. + attributeVersion uint64 + + // attributeTime is the remaining vaild time of attributes. + attributeTime uint64 + + // version of the inode. + version uint64 + + // link is result of following a symbolic link. + link string +} + +func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { + i := &inode{fs: fs} + i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + i.EnableLeakCheck() + i.dentry.Init(i) + i.nodeID = 1 + + return &i.dentry +} + +func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) *kernfs.Dentry { + i := &inode{fs: fs, nodeID: nodeID} + creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} + i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + atomic.StoreUint64(&i.size, attr.Size) + i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + i.EnableLeakCheck() + i.dentry.Init(i) + + return &i.dentry +} + +// Open implements kernfs.Inode.Open. +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + isDir := i.InodeAttrs.Mode().IsDir() + // return error if specified to open directory but inode is not a directory. + if !isDir && opts.Mode.IsDir() { + return nil, syserror.ENOTDIR + } + if opts.Flags&linux.O_LARGEFILE == 0 && atomic.LoadUint64(&i.size) > linux.MAX_NON_LFS { + return nil, syserror.EOVERFLOW + } + + var fd *fileDescription + var fdImpl vfs.FileDescriptionImpl + if isDir { + directoryFD := &directoryFD{} + fd = &(directoryFD.fileDescription) + fdImpl = directoryFD + } else { + regularFD := ®ularFileFD{} + fd = &(regularFD.fileDescription) + fdImpl = regularFD + } + // FOPEN_KEEP_CACHE is the defualt flag for noOpen. + fd.OpenFlag = linux.FOPEN_KEEP_CACHE + + // Only send open request when FUSE server support open or is opening a directory. + if !i.fs.conn.noOpen || isDir { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.Open: couldn't get kernel task from context") + return nil, syserror.EINVAL + } + + // Build the request. + var opcode linux.FUSEOpcode + if isDir { + opcode = linux.FUSE_OPENDIR + } else { + opcode = linux.FUSE_OPEN + } + + in := linux.FUSEOpenIn{Flags: opts.Flags & ^uint32(linux.O_CREAT|linux.O_EXCL|linux.O_NOCTTY)} + if !i.fs.conn.atomicOTrunc { + in.Flags &= ^uint32(linux.O_TRUNC) + } + + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) + if err != nil { + return nil, err + } + + // Send the request and receive the reply. + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return nil, err + } + if err := res.Error(); err == syserror.ENOSYS && !isDir { + i.fs.conn.noOpen = true + } else if err != nil { + return nil, err + } else { + out := linux.FUSEOpenOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return nil, err + } + + // Process the reply. + fd.OpenFlag = out.OpenFlag + if isDir { + fd.OpenFlag &= ^uint32(linux.FOPEN_DIRECT_IO) + } + + fd.Fh = out.Fh + } + } + + // TODO(gvisor.dev/issue/3234): invalidate mmap after implemented it for FUSE Inode + fd.DirectIO = fd.OpenFlag&linux.FOPEN_DIRECT_IO != 0 + fdOptions := &vfs.FileDescriptionOptions{} + if fd.OpenFlag&linux.FOPEN_NONSEEKABLE != 0 { + fdOptions.DenyPRead = true + fdOptions.DenyPWrite = true + fd.Nonseekable = true + } + + // If we don't send SETATTR before open (which is indicated by atomicOTrunc) + // and O_TRUNC is set, update the inode's version number and clean existing data + // by setting the file size to 0. + if i.fs.conn.atomicOTrunc && opts.Flags&linux.O_TRUNC != 0 { + i.fs.conn.mu.Lock() + i.fs.conn.attributeVersion++ + i.attributeVersion = i.fs.conn.attributeVersion + atomic.StoreUint64(&i.size, 0) + i.fs.conn.mu.Unlock() + i.attributeTime = 0 + } + + if err := fd.vfsfd.Init(fdImpl, opts.Flags, rp.Mount(), d.VFSDentry(), fdOptions); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// Lookup implements kernfs.Inode.Lookup. +func (i *inode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { + in := linux.FUSELookupIn{Name: name} + return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in) +} + +// IterDirents implements kernfs.Inode.IterDirents. +func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return offset, nil +} + +// Valid implements kernfs.Inode.Valid. +func (*inode) Valid(ctx context.Context) bool { + return true +} + +// NewFile implements kernfs.Inode.NewFile. +func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID) + return nil, syserror.EINVAL + } + in := linux.FUSECreateIn{ + CreateMeta: linux.FUSECreateMeta{ + Flags: opts.Flags, + Mode: uint32(opts.Mode) | linux.S_IFREG, + Umask: uint32(kernelTask.FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in) +} + +// NewNode implements kernfs.Inode.NewNode. +func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*kernfs.Dentry, error) { + in := linux.FUSEMknodIn{ + MknodMeta: linux.FUSEMknodMeta{ + Mode: uint32(opts.Mode), + Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor), + Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in) +} + +// NewSymlink implements kernfs.Inode.NewSymlink. +func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.Dentry, error) { + in := linux.FUSESymLinkIn{ + Name: name, + Target: target, + } + return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in) +} + +// Unlink implements kernfs.Inode.Unlink. +func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) error { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) + return syserror.EINVAL + } + in := linux.FUSEUnlinkIn{Name: name} + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) + if err != nil { + return err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return err + } + // only return error, discard res. + if err := res.Error(); err != nil { + return err + } + return i.dentry.RemoveChildLocked(name, child) +} + +// NewDir implements kernfs.Inode.NewDir. +func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) { + in := linux.FUSEMkdirIn{ + MkdirMeta: linux.FUSEMkdirMeta{ + Mode: uint32(opts.Mode), + Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in) +} + +// RmDir implements kernfs.Inode.RmDir. +func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) error { + fusefs := i.fs + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + + in := linux.FUSERmDirIn{Name: name} + req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) + if err != nil { + return err + } + + res, err := i.fs.conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + + return i.dentry.RemoveChildLocked(name, child) +} + +// newEntry calls FUSE server for entry creation and allocates corresponding entry according to response. +// Shared by FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and FUSE_LOOKUP. +func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*kernfs.Dentry, error) { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) + return nil, syserror.EINVAL + } + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) + if err != nil { + return nil, err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return nil, err + } + if err := res.Error(); err != nil { + return nil, err + } + out := linux.FUSEEntryOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return nil, err + } + if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { + return nil, syserror.EIO + } + child := i.fs.newInode(out.NodeID, out.Attr) + return child, nil +} + +// Getlink implements kernfs.Inode.Getlink. +func (i *inode) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + path, err := i.Readlink(ctx, mnt) + return vfs.VirtualDentry{}, path, err +} + +// Readlink implements kernfs.Inode.Readlink. +func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { + if i.Mode().FileType()&linux.S_IFLNK == 0 { + return "", syserror.EINVAL + } + if len(i.link) == 0 { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") + return "", syserror.EINVAL + } + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) + if err != nil { + return "", err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return "", err + } + i.link = string(res.data[res.hdr.SizeBytes():]) + if !mnt.Options().ReadOnly { + i.attributeTime = 0 + } + } + return i.link, nil +} + +// getFUSEAttr returns a linux.FUSEAttr of this inode stored in local cache. +// TODO(gvisor.dev/issue/3679): Add support for other fields. +func (i *inode) getFUSEAttr() linux.FUSEAttr { + return linux.FUSEAttr{ + Ino: i.Ino(), + Size: atomic.LoadUint64(&i.size), + Mode: uint32(i.Mode()), + } +} + +// statFromFUSEAttr makes attributes from linux.FUSEAttr to linux.Statx. The +// opts.Sync attribute is ignored since the synchronization is handled by the +// FUSE server. +func statFromFUSEAttr(attr linux.FUSEAttr, mask, devMinor uint32) linux.Statx { + var stat linux.Statx + stat.Blksize = attr.BlkSize + stat.DevMajor, stat.DevMinor = linux.UNNAMED_MAJOR, devMinor + + rdevMajor, rdevMinor := linux.DecodeDeviceID(attr.Rdev) + stat.RdevMajor, stat.RdevMinor = uint32(rdevMajor), rdevMinor + + if mask&linux.STATX_MODE != 0 { + stat.Mode = uint16(attr.Mode) + } + if mask&linux.STATX_NLINK != 0 { + stat.Nlink = attr.Nlink + } + if mask&linux.STATX_UID != 0 { + stat.UID = attr.UID + } + if mask&linux.STATX_GID != 0 { + stat.GID = attr.GID + } + if mask&linux.STATX_ATIME != 0 { + stat.Atime = linux.StatxTimestamp{ + Sec: int64(attr.Atime), + Nsec: attr.AtimeNsec, + } + } + if mask&linux.STATX_MTIME != 0 { + stat.Mtime = linux.StatxTimestamp{ + Sec: int64(attr.Mtime), + Nsec: attr.MtimeNsec, + } + } + if mask&linux.STATX_CTIME != 0 { + stat.Ctime = linux.StatxTimestamp{ + Sec: int64(attr.Ctime), + Nsec: attr.CtimeNsec, + } + } + if mask&linux.STATX_INO != 0 { + stat.Ino = attr.Ino + } + if mask&linux.STATX_SIZE != 0 { + stat.Size = attr.Size + } + if mask&linux.STATX_BLOCKS != 0 { + stat.Blocks = attr.Blocks + } + return stat +} + +// getAttr gets the attribute of this inode by issuing a FUSE_GETATTR request +// or read from local cache. It updates the corresponding attributes if +// necessary. +func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions, flags uint32, fh uint64) (linux.FUSEAttr, error) { + attributeVersion := atomic.LoadUint64(&i.fs.conn.attributeVersion) + + // TODO(gvisor.dev/issue/3679): send the request only if + // - invalid local cache for fields specified in the opts.Mask + // - forced update + // - i.attributeTime expired + // If local cache is still valid, return local cache. + // Currently we always send a request, + // and we always set the metadata with the new result, + // unless attributeVersion has changed. + + task := kernel.TaskFromContext(ctx) + if task == nil { + log.Warningf("couldn't get kernel task from context") + return linux.FUSEAttr{}, syserror.EINVAL + } + + creds := auth.CredentialsFromContext(ctx) + + in := linux.FUSEGetAttrIn{ + GetAttrFlags: flags, + Fh: fh, + } + req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) + if err != nil { + return linux.FUSEAttr{}, err + } + + res, err := i.fs.conn.Call(task, req) + if err != nil { + return linux.FUSEAttr{}, err + } + if err := res.Error(); err != nil { + return linux.FUSEAttr{}, err + } + + var out linux.FUSEGetAttrOut + if err := res.UnmarshalPayload(&out); err != nil { + return linux.FUSEAttr{}, err + } + + // Local version is newer, return the local one. + // Skip the update. + if attributeVersion != 0 && atomic.LoadUint64(&i.attributeVersion) > attributeVersion { + return i.getFUSEAttr(), nil + } + + // Set the metadata of kernfs.InodeAttrs. + if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), + }); err != nil { + return linux.FUSEAttr{}, err + } + + // Set the size if no error (after SetStat() check). + atomic.StoreUint64(&i.size, out.Attr.Size) + + return out.Attr, nil +} + +// reviseAttr attempts to update the attributes for internal purposes +// by calling getAttr with a pre-specified mask. +// Used by read, write, lseek. +func (i *inode) reviseAttr(ctx context.Context, flags uint32, fh uint64) error { + // Never need atime for internal purposes. + _, err := i.getAttr(ctx, i.fs.VFSFilesystem(), vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS &^ linux.STATX_ATIME, + }, flags, fh) + return err +} + +// Stat implements kernfs.Inode.Stat. +func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + attr, err := i.getAttr(ctx, fs, opts, 0, 0) + if err != nil { + return linux.Statx{}, err + } + + return statFromFUSEAttr(attr, opts.Mask, i.fs.devMinor), nil +} + +// DecRef implements kernfs.Inode.DecRef. +func (i *inode) DecRef(context.Context) { + i.inodeRefs.DecRef(i.Destroy) +} + +// StatFS implements kernfs.Inode.StatFS. +func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + // TODO(gvisor.dev/issues/3413): Complete the implementation of statfs. + return vfs.GenericStatFS(linux.FUSE_SUPER_MAGIC), nil +} + +// fattrMaskFromStats converts vfs.SetStatOptions.Stat.Mask to linux stats mask +// aligned with the attribute mask defined in include/linux/fs.h. +func fattrMaskFromStats(mask uint32) uint32 { + var fuseAttrMask uint32 + maskMap := map[uint32]uint32{ + linux.STATX_MODE: linux.FATTR_MODE, + linux.STATX_UID: linux.FATTR_UID, + linux.STATX_GID: linux.FATTR_GID, + linux.STATX_SIZE: linux.FATTR_SIZE, + linux.STATX_ATIME: linux.FATTR_ATIME, + linux.STATX_MTIME: linux.FATTR_MTIME, + linux.STATX_CTIME: linux.FATTR_CTIME, + } + for statxMask, fattrMask := range maskMap { + if mask&statxMask != 0 { + fuseAttrMask |= fattrMask + } + } + return fuseAttrMask +} + +// SetStat implements kernfs.Inode.SetStat. +func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return i.setAttr(ctx, fs, creds, opts, false, 0) +} + +func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions, useFh bool, fh uint64) error { + conn := i.fs.conn + task := kernel.TaskFromContext(ctx) + if task == nil { + log.Warningf("couldn't get kernel task from context") + return syserror.EINVAL + } + + // We should retain the original file type when assigning new mode. + fileType := uint16(i.Mode()) & linux.S_IFMT + fattrMask := fattrMaskFromStats(opts.Stat.Mask) + if useFh { + fattrMask |= linux.FATTR_FH + } + in := linux.FUSESetAttrIn{ + Valid: fattrMask, + Fh: fh, + Size: opts.Stat.Size, + Atime: uint64(opts.Stat.Atime.Sec), + Mtime: uint64(opts.Stat.Mtime.Sec), + Ctime: uint64(opts.Stat.Ctime.Sec), + AtimeNsec: opts.Stat.Atime.Nsec, + MtimeNsec: opts.Stat.Mtime.Nsec, + CtimeNsec: opts.Stat.Ctime.Nsec, + Mode: uint32(fileType | opts.Stat.Mode), + UID: opts.Stat.UID, + GID: opts.Stat.GID, + } + req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) + if err != nil { + return err + } + + res, err := conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + out := linux.FUSEGetAttrOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + // Set the metadata of kernfs.InodeAttrs. + if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), + }); err != nil { + return err + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go new file mode 100644 index 000000000..625d1547f --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -0,0 +1,242 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// ReadInPages sends FUSE_READ requests for the size after round it up to +// a multiple of page size, blocks on it for reply, processes the reply +// and returns the payload (or joined payloads) as a byte slice. +// This is used for the general purpose reading. +// We do not support direct IO (which read the exact number of bytes) +// at this moment. +func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off uint64, size uint32) ([][]byte, uint32, error) { + attributeVersion := atomic.LoadUint64(&fs.conn.attributeVersion) + + t := kernel.TaskFromContext(ctx) + if t == nil { + log.Warningf("fusefs.Read: couldn't get kernel task from context") + return nil, 0, syserror.EINVAL + } + + // Round up to a multiple of page size. + readSize, _ := usermem.PageRoundUp(uint64(size)) + + // One request cannnot exceed either maxRead or maxPages. + maxPages := fs.conn.maxRead >> usermem.PageShift + if maxPages > uint32(fs.conn.maxPages) { + maxPages = uint32(fs.conn.maxPages) + } + + var outs [][]byte + var sizeRead uint32 + + // readSize is a multiple of usermem.PageSize. + // Always request bytes as a multiple of pages. + pagesRead, pagesToRead := uint32(0), uint32(readSize>>usermem.PageShift) + + // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. + in := linux.FUSEReadIn{ + Fh: fd.Fh, + LockOwner: 0, // TODO(gvisor.dev/issue/3245): file lock + ReadFlags: 0, // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER + Flags: fd.statusFlags(), + } + + // This loop is intended for fragmented read where the bytes to read is + // larger than either the maxPages or maxRead. + // For the majority of reads with normal size, this loop should only + // execute once. + for pagesRead < pagesToRead { + pagesCanRead := pagesToRead - pagesRead + if pagesCanRead > maxPages { + pagesCanRead = maxPages + } + + in.Offset = off + (uint64(pagesRead) << usermem.PageShift) + in.Size = pagesCanRead << usermem.PageShift + + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) + if err != nil { + return nil, 0, err + } + + // TODO(gvisor.dev/issue/3247): support async read. + + res, err := fs.conn.Call(t, req) + if err != nil { + return nil, 0, err + } + if err := res.Error(); err != nil { + return nil, 0, err + } + + // Not enough bytes in response, + // either we reached EOF, + // or the FUSE server sends back a response + // that cannot even fit the hdr. + if len(res.data) <= res.hdr.SizeBytes() { + // We treat both case as EOF here for now + // since there is no reliable way to detect + // the over-short hdr case. + break + } + + // Directly using the slice to avoid extra copy. + out := res.data[res.hdr.SizeBytes():] + + outs = append(outs, out) + sizeRead += uint32(len(out)) + + pagesRead += pagesCanRead + } + + defer fs.ReadCallback(ctx, fd, off, size, sizeRead, attributeVersion) + + // No bytes returned: offset >= EOF. + if len(outs) == 0 { + return nil, 0, io.EOF + } + + return outs, sizeRead, nil +} + +// ReadCallback updates several information after receiving a read response. +// Due to readahead, sizeRead can be larger than size. +func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off uint64, size uint32, sizeRead uint32, attributeVersion uint64) { + // TODO(gvisor.dev/issue/3247): support async read. + // If this is called by an async read, correctly process it. + // May need to update the signature. + + i := fd.inode() + // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + + // Reached EOF. + if sizeRead < size { + // TODO(gvisor.dev/issue/3630): If we have writeback cache, then we need to fill this hole. + // Might need to update the buf to be returned from the Read(). + + // Update existing size. + newSize := off + uint64(sizeRead) + fs.conn.mu.Lock() + if attributeVersion == i.attributeVersion && newSize < atomic.LoadUint64(&i.size) { + fs.conn.attributeVersion++ + i.attributeVersion = i.fs.conn.attributeVersion + atomic.StoreUint64(&i.size, newSize) + } + fs.conn.mu.Unlock() + } +} + +// Write sends FUSE_WRITE requests and return the bytes +// written according to the response. +// +// Preconditions: len(data) == size. +func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, size uint32, data []byte) (uint32, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + log.Warningf("fusefs.Read: couldn't get kernel task from context") + return 0, syserror.EINVAL + } + + // One request cannnot exceed either maxWrite or maxPages. + maxWrite := uint32(fs.conn.maxPages) << usermem.PageShift + if maxWrite > fs.conn.maxWrite { + maxWrite = fs.conn.maxWrite + } + + // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. + in := linux.FUSEWriteIn{ + Fh: fd.Fh, + // TODO(gvisor.dev/issue/3245): file lock + LockOwner: 0, + // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER + // TODO(gvisor.dev/issue/3237): |= linux.FUSE_WRITE_CACHE (not added yet) + WriteFlags: 0, + Flags: fd.statusFlags(), + } + + var written uint32 + + // This loop is intended for fragmented write where the bytes to write is + // larger than either the maxWrite or maxPages or when bigWrites is false. + // Unless a small value for max_write is explicitly used, this loop + // is expected to execute only once for the majority of the writes. + for written < size { + toWrite := size - written + + // Limit the write size to one page. + // Note that the bigWrites flag is obsolete, + // latest libfuse always sets it on. + if !fs.conn.bigWrites && toWrite > usermem.PageSize { + toWrite = usermem.PageSize + } + + // Limit the write size to maxWrite. + if toWrite > maxWrite { + toWrite = maxWrite + } + + in.Offset = off + uint64(written) + in.Size = toWrite + + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + if err != nil { + return 0, err + } + + req.payload = data[written : written+toWrite] + + // TODO(gvisor.dev/issue/3247): support async write. + + res, err := fs.conn.Call(t, req) + if err != nil { + return 0, err + } + if err := res.Error(); err != nil { + return 0, err + } + + out := linux.FUSEWriteOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return 0, err + } + + // Write more than requested? EIO. + if out.Size > toWrite { + return 0, syserror.EIO + } + + written += out.Size + + // Break if short write. Not necessarily an error. + if out.Size != toWrite { + break + } + } + + return written, nil +} diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go new file mode 100644 index 000000000..b5b581152 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/register.go @@ -0,0 +1,42 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// Register registers the FUSE device with vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "misc", + }); err != nil { + return err + } + + return nil +} + +// CreateDevtmpfsFile creates a device special file in devtmpfs. +func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error { + if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil { + return err + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go new file mode 100644 index 000000000..5bdd096c3 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/regular_file.go @@ -0,0 +1,230 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "math" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type regularFileFD struct { + fileDescription + + // off is the file offset. + off int64 + // offMu protects off. + offMu sync.Mutex +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if offset < 0 { + return 0, syserror.EINVAL + } + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + + size := dst.NumBytes() + if size == 0 { + // Early return if count is 0. + return 0, nil + } else if size > math.MaxUint32 { + // FUSE only supports uint32 for size. + // Overflow. + return 0, syserror.EINVAL + } + + // TODO(gvisor.dev/issue/3678): Add direct IO support. + + inode := fd.inode() + + // Reading beyond EOF, update file size if outdated. + if uint64(offset+size) > atomic.LoadUint64(&inode.size) { + if err := inode.reviseAttr(ctx, linux.FUSE_GETATTR_FH, fd.Fh); err != nil { + return 0, err + } + // If the offset after update is still too large, return error. + if uint64(offset) >= atomic.LoadUint64(&inode.size) { + return 0, io.EOF + } + } + + // Truncate the read with updated file size. + fileSize := atomic.LoadUint64(&inode.size) + if uint64(offset+size) > fileSize { + size = int64(fileSize) - offset + } + + buffers, n, err := inode.fs.ReadInPages(ctx, fd, uint64(offset), uint32(size)) + if err != nil { + return 0, err + } + + // TODO(gvisor.dev/issue/3237): support indirect IO (e.g. caching), + // store the bytes that were read ahead. + + // Update the number of bytes to copy for short read. + if n < uint32(size) { + size = int64(n) + } + + // Copy the bytes read to the dst. + // This loop is intended for fragmented reads. + // For the majority of reads, this loop only execute once. + var copied int64 + for _, buffer := range buffers { + toCopy := int64(len(buffer)) + if copied+toCopy > size { + toCopy = size - copied + } + cp, err := dst.DropFirst64(copied).CopyOut(ctx, buffer[:toCopy]) + if err != nil { + return 0, err + } + if int64(cp) != toCopy { + return 0, syserror.EIO + } + copied += toCopy + } + + return copied, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.offMu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.offMu.Unlock() + return n, err +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + fd.offMu.Lock() + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off + fd.offMu.Unlock() + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { + if offset < 0 { + return 0, offset, syserror.EINVAL + } + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + + inode := fd.inode() + inode.metadataMu.Lock() + defer inode.metadataMu.Unlock() + + // If the file is opened with O_APPEND, update offset to file size. + // Note: since our Open() implements the interface of kernfs, + // and kernfs currently does not support O_APPEND, this will never + // be true before we switch out from kernfs. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking inode.metadataMu is sufficient for reading size + offset = int64(inode.size) + } + + srclen := src.NumBytes() + + if srclen > math.MaxUint32 { + // FUSE only supports uint32 for size. + // Overflow. + return 0, offset, syserror.EINVAL + } + if end := offset + srclen; end < offset { + // Overflow. + return 0, offset, syserror.EINVAL + } + + srclen, err = vfs.CheckLimit(ctx, offset, srclen) + if err != nil { + return 0, offset, err + } + + if srclen == 0 { + // Return before causing any side effects. + return 0, offset, nil + } + + src = src.TakeFirst64(srclen) + + // TODO(gvisor.dev/issue/3237): Add cache support: + // buffer cache. Ideally we write from src to our buffer cache first. + // The slice passed to fs.Write() should be a slice from buffer cache. + data := make([]byte, srclen) + // Reason for making a copy here: connection.Call() blocks on kerneltask, + // which in turn acquires mm.activeMu lock. Functions like CopyInTo() will + // attemp to acquire the mm.activeMu lock as well -> deadlock. + // We must finish reading from the userspace memory before + // t.Block() deactivates it. + cp, err := src.CopyIn(ctx, data) + if err != nil { + return 0, offset, err + } + if int64(cp) != srclen { + return 0, offset, syserror.EIO + } + + n, err := fd.inode().fs.Write(ctx, fd, uint64(offset), uint32(srclen), data) + if err != nil { + return 0, offset, err + } + + if n == 0 { + // We have checked srclen != 0 previously. + // If err == nil, then it's a short write and we return EIO. + return 0, offset, syserror.EIO + } + + written = int64(n) + finalOff = offset + written + + if finalOff > int64(inode.size) { + atomic.StoreUint64(&inode.size, uint64(finalOff)) + atomic.AddUint64(&inode.fs.conn.attributeVersion, 1) + } + + return +} diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go new file mode 100644 index 000000000..7fa00569b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "fmt" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/usermem" +) + +// fuseInitRes is a variable-length wrapper of linux.FUSEInitOut. The FUSE +// server may implement an older version of FUSE protocol, which contains a +// linux.FUSEInitOut with less attributes. +// +// Dynamically-sized objects cannot be marshalled. +type fuseInitRes struct { + marshal.StubMarshallable + + // initOut contains the response from the FUSE server. + initOut linux.FUSEInitOut + + // initLen is the total length of bytes of the response. + initLen uint32 +} + +// UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes. +func (r *fuseInitRes) UnmarshalBytes(src []byte) { + out := &r.initOut + + // Introduced before FUSE kernel version 7.13. + out.Major = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.Minor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.MaxReadahead = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.Flags = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.MaxBackground = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + out.CongestionThreshold = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + out.MaxWrite = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + + // Introduced in FUSE kernel version 7.23. + if len(src) >= 4 { + out.TimeGran = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + // Introduced in FUSE kernel version 7.28. + if len(src) >= 2 { + out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + } +} + +// SizeBytes is the size of the payload of the FUSE_INIT response. +func (r *fuseInitRes) SizeBytes() int { + return int(r.initLen) +} + +// Ordinary requests have even IDs, while interrupts IDs are odd. +// Used to increment the unique ID for each FUSE request. +var reqIDStep uint64 = 2 + +// Request represents a FUSE operation request that hasn't been sent to the +// server yet. +// +// +stateify savable +type Request struct { + requestEntry + + id linux.FUSEOpID + hdr *linux.FUSEHeaderIn + data []byte + + // payload for this request: extra bytes to write after + // the data slice. Used by FUSE_WRITE. + payload []byte + + // If this request is async. + async bool + // If we don't care its response. + // Manually set by the caller. + noReply bool +} + +// NewRequest creates a new request that can be sent to the FUSE server. +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) + + hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() + hdr := linux.FUSEHeaderIn{ + Len: uint32(hdrLen + payload.SizeBytes()), + Opcode: opcode, + Unique: conn.fd.nextOpID, + NodeID: ino, + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + PID: pid, + } + + buf := make([]byte, hdr.Len) + + // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again. + hdr.MarshalBytes(buf[:hdrLen]) + payload.MarshalBytes(buf[hdrLen:]) + + return &Request{ + id: hdr.Unique, + hdr: &hdr, + data: buf, + }, nil +} + +// futureResponse represents an in-flight request, that may or may not have +// completed yet. Convert it to a resolved Response by calling Resolve, but note +// that this may block. +// +// +stateify savable +type futureResponse struct { + opcode linux.FUSEOpcode + ch chan struct{} + hdr *linux.FUSEHeaderOut + data []byte + + // If this request is async. + async bool +} + +// newFutureResponse creates a future response to a FUSE request. +func newFutureResponse(req *Request) *futureResponse { + return &futureResponse{ + opcode: req.hdr.Opcode, + ch: make(chan struct{}), + async: req.async, + } +} + +// resolve blocks the task until the server responds to its corresponding request, +// then returns a resolved response. +func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { + // Return directly for async requests. + if f.async { + return nil, nil + } + + if err := t.Block(f.ch); err != nil { + return nil, err + } + + return f.getResponse(), nil +} + +// getResponse creates a Response from the data the futureResponse has. +func (f *futureResponse) getResponse() *Response { + return &Response{ + opcode: f.opcode, + hdr: *f.hdr, + data: f.data, + } +} + +// Response represents an actual response from the server, including the +// response payload. +// +// +stateify savable +type Response struct { + opcode linux.FUSEOpcode + hdr linux.FUSEHeaderOut + data []byte +} + +// Error returns the error of the FUSE call. +func (r *Response) Error() error { + errno := r.hdr.Error + if errno >= 0 { + return nil + } + + sysErrNo := syscall.Errno(-errno) + return error(sysErrNo) +} + +// DataLen returns the size of the response without the header. +func (r *Response) DataLen() uint32 { + return r.hdr.Len - uint32(r.hdr.SizeBytes()) +} + +// UnmarshalPayload unmarshals the response data into m. +func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { + hdrLen := r.hdr.SizeBytes() + haveDataLen := r.hdr.Len - uint32(hdrLen) + wantDataLen := uint32(m.SizeBytes()) + + if haveDataLen < wantDataLen { + return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) + } + + // The response data is empty unless there is some payload. And so, doesn't + // need to be unmarshalled. + if r.data == nil { + return nil + } + + // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again. + m.UnmarshalBytes(r.data[hdrLen:]) + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go new file mode 100644 index 000000000..e1d9e3365 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +func setup(t *testing.T) *testutil.System { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("Error creating kernel: %v", err) + } + + ctx := k.SupervisorContext() + creds := auth.CredentialsFromContext(ctx) + + k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + AllowUserMount: true, + }) + + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) + if err != nil { + t.Fatalf("NewMountNamespace(): %v", err) + } + + return testutil.NewSystem(ctx, t, k.VFS(), mntns) +} + +// newTestConnection creates a fuse connection that the sentry can communicate with +// and the FD for the server to communicate with. +func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { + vfsObj := &vfs.VirtualFilesystem{} + fuseDev := &DeviceFD{} + + if err := vfsObj.Init(system.Ctx); err != nil { + return nil, nil, err + } + + vd := vfsObj.NewAnonVirtualDentry("genCountFD") + defer vd.DecRef(system.Ctx) + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, nil, err + } + + fsopts := filesystemOptions{ + maxActiveRequests: maxActiveRequests, + } + fs, err := newFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + if err != nil { + return nil, nil, err + } + + return fs.conn, &fuseDev.vfsfd, nil +} + +type testPayload struct { + marshal.StubMarshallable + data uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *testPayload) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *testPayload) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], t.data) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *testPayload) UnmarshalBytes(src []byte) { + *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} +} + +// Packed implements marshal.Marshallable.Packed. +func (t *testPayload) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *testPayload) MarshalUnsafe(dst []byte) { + t.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *testPayload) UnmarshalUnsafe(src []byte) { + t.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (t *testPayload) CopyOutN(task marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + panic("not implemented") +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *testPayload) CopyOut(task marshal.CopyContext, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *testPayload) CopyIn(task marshal.CopyContext, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *testPayload) WriteTo(w io.Writer) (int64, error) { + panic("not implemented") +} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 4a800dcf9..16787116f 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -85,5 +85,6 @@ go_test( deps = [ "//pkg/p9", "//pkg/sentry/contexttest", + "//pkg/sentry/pgalloc", ], ) diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 8c7c8e1b3..18c884b59 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -34,8 +34,11 @@ func (d *dentry) isDir() bool { return d.fileType() == linux.S_IFDIR } -// Preconditions: filesystem.renameMu must be locked. d.dirMu must be locked. -// d.isDir(). child must be a newly-created dentry that has never had a parent. +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +// * child must be a newly-created dentry that has never had a parent. func (d *dentry) cacheNewChildLocked(child *dentry, name string) { d.IncRef() // reference held by child on its parent child.parent = d @@ -46,7 +49,9 @@ func (d *dentry) cacheNewChildLocked(child *dentry, name string) { d.children[name] = child } -// Preconditions: d.dirMu must be locked. d.isDir(). +// Preconditions: +// * d.dirMu must be locked. +// * d.isDir(). func (d *dentry) cacheNegativeLookupLocked(name string) { // Don't cache negative lookups if InteropModeShared is in effect (since // this makes remote lookup unavoidable), or if d.isSynthetic() (in which @@ -79,10 +84,12 @@ type createSyntheticOpts struct { // createSyntheticChildLocked creates a synthetic file with the given name // in d. // -// Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain -// a child with the given name. +// Preconditions: +// * d.dirMu must be locked. +// * d.isDir(). +// * d does not already contain a child with the given name. func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { - d2 := &dentry{ + child := &dentry{ refs: 1, // held by d fs: d.fs, ino: d.fs.nextSyntheticIno(), @@ -90,39 +97,38 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { uid: uint32(opts.kuid), gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary - handle: handle{ - fd: -1, - }, - nlink: uint32(2), + hostFD: -1, + nlink: uint32(2), } switch opts.mode.FileType() { case linux.S_IFDIR: // Nothing else needs to be done. case linux.S_IFSOCK: - d2.endpoint = opts.endpoint + child.endpoint = opts.endpoint case linux.S_IFIFO: - d2.pipe = opts.pipe + child.pipe = opts.pipe default: panic(fmt.Sprintf("failed to create synthetic file of unrecognized type: %v", opts.mode.FileType())) } - d2.pf.dentry = d2 - d2.vfsd.Init(d2) + child.pf.dentry = child + child.vfsd.Init(child) - d.cacheNewChildLocked(d2, opts.name) + d.cacheNewChildLocked(child, opts.name) d.syntheticChildren++ } +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 dirents []vfs.Dirent } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(context.Context) { } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. @@ -139,7 +145,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.dirents = ds } - d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent) if d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } @@ -153,7 +159,9 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba return nil } -// Preconditions: d.isDir(). There exists at least one directoryFD representing d. +// Preconditions: +// * d.isDir(). +// * There exists at least one directoryFD representing d. func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { // NOTE(b/135560623): 9P2000.L's readdir does not specify behavior in the // presence of concurrent mutation of an iterated directory, so @@ -205,14 +213,14 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { off := uint64(0) const count = 64 * 1024 // for consistency with the vfs1 client d.handleMu.RLock() - if !d.handleReadable { + if d.readFile.isNil() { // This should not be possible because a readable handle should // have been opened when the calling directoryFD was opened. d.handleMu.RUnlock() panic("gofer.dentry.getDirents called without a readable handle") } for { - p9ds, err := d.handle.file.readdir(ctx, off, count) + p9ds, err := d.readFile.readdir(ctx, off, count) if err != nil { d.handleMu.RUnlock() return nil, err @@ -304,5 +312,5 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *directoryFD) Sync(ctx context.Context) error { - return fd.dentry().handle.sync(ctx) + return fd.dentry().syncRemoteFile(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index cd5f5049e..94d96261b 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -15,6 +15,7 @@ package gofer import ( + "math" "sync" "sync/atomic" @@ -54,8 +55,8 @@ func (fs *filesystem) Sync(ctx context.Context) error { // Sync regular files. for _, d := range ds { - err := d.syncSharedHandle(ctx) - d.DecRef() + err := d.syncCachedFile(ctx) + d.DecRef(ctx) if err != nil && retErr == nil { retErr = err } @@ -65,7 +66,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // handles (so they won't be synced by the above). for _, sffd := range sffds { err := sffd.Sync(ctx) - sffd.vfsfd.DecRef() + sffd.vfsfd.DecRef(ctx) if err != nil && retErr == nil { retErr = err } @@ -114,9 +115,12 @@ func putDentrySlice(ds *[]*dentry) { // Dentries which may become cached as a result of the traversal are appended // to *ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). If !d.cachedMetadataAuthoritative(), then d's cached metadata -// must be up to date. +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). +// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up +// to date. // // Postconditions: The returned dentry's cached metadata is up to date. func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { @@ -133,7 +137,7 @@ afterSymlink: return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() @@ -146,15 +150,13 @@ afterSymlink: // // Call rp.CheckMount() before updating d.parent's metadata, since if // we traverse to another mount then d.parent's metadata is irrelevant. - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { - _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { + if err := d.parent.updateFromGetattr(ctx); err != nil { return nil, err } - d.parent.updateFromP9Attrs(attrMask, &attr) } rp.Advance() return d.parent, nil @@ -166,7 +168,7 @@ afterSymlink: if child == nil { return nil, syserror.ENOENT } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { @@ -186,8 +188,11 @@ afterSymlink: // getChildLocked returns a dentry representing the child of parent with the // given name. If no such child exists, getChildLocked returns (nil, nil). // -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. -// parent.isDir(). name is not "." or "..". +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. +// * parent.isDir(). +// * name is not "." or "..". // // Postconditions: If getChildLocked returns a non-nil dentry, its cached // metadata is up to date. @@ -207,19 +212,31 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds) } -// Preconditions: As for getChildLocked. !parent.isSynthetic(). +// Preconditions: Same as getChildLocked, plus: +// * !parent.isSynthetic(). func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { + if child != nil { + // Need to lock child.metadataMu because we might be updating child + // metadata. We need to hold the lock *before* getting metadata from the + // server and release it after updating local metadata. + child.metadataMu.Lock() + } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil && err != syserror.ENOENT { + if child != nil { + child.metadataMu.Unlock() + } return nil, err } if child != nil { if !file.isNil() && inoFromPath(qid.Path) == child.ino { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) + child.updateFromP9AttrsLocked(attrMask, &attr) + child.metadataMu.Unlock() return child, nil } + child.metadataMu.Unlock() if file.isNil() && child.isSynthetic() { // We have a synthetic file, and no remote file has arisen to // replace it. @@ -230,7 +247,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // has 0 references, drop it). Wait to update parent.children until we // know what to replace the existing dentry with (i.e. one of the // returns below), to avoid a redundant map access. - vfsObj.InvalidateDentry(&child.vfsd) + vfsObj.InvalidateDentry(ctx, &child.vfsd) if child.isSynthetic() { // Normally we don't mark invalidated dentries as deleted since // they may still exist (but at a different path), and also for @@ -269,9 +286,11 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). If -// !d.cachedMetadataAuthoritative(), then d's cached metadata must be up to -// date. +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). +// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up +// to date. func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -318,12 +337,13 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // createInRemoteDir (if the parent directory is a real remote directory) or // createInSyntheticDir (if the parent directory is synthetic) to do so. // -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error { +// Preconditions: +// * !rp.Done(). +// * For the final path component in rp, !rp.ShouldFollowSymlink(). +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { // Get updated metadata for start as required by @@ -375,7 +395,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if dir { ev |= linux.IN_ISDIR } - parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } if fs.opts.interop == InteropModeShared { @@ -389,14 +409,14 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a // stale dentry exists, the dentry will fail revalidation next time it's // used. - if err := createInRemoteDir(parent, name); err != nil { + if err := createInRemoteDir(parent, name, &ds); err != nil { return err } ev := linux.IN_CREATE if dir { ev |= linux.IN_ISDIR } - parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } if child := parent.children[name]; child != nil { @@ -404,7 +424,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } // No cached dentry exists; however, there might still be an existing file // at name. As above, we attempt the file creation RPC anyway. - if err := createInRemoteDir(parent, name); err != nil { + if err := createInRemoteDir(parent, name, &ds); err != nil { return err } if child, ok := parent.children[name]; ok && child == nil { @@ -417,7 +437,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if dir { ev |= linux.IN_ISDIR } - parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } @@ -425,7 +445,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { // Get updated metadata for start as required by @@ -461,7 +481,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -591,17 +611,17 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // Generate inotify events for rmdir or unlink. if dir { - parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) + parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) } else { var cw *vfs.Watches if child != nil { cw = &child.watches } - vfs.InotifyRemoveChild(cw, &parent.watches, name) + vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name) } if child != nil { - vfsObj.CommitDeleteDentry(&child.vfsd) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) child.setDeleted() if child.isSynthetic() { parent.syntheticChildren-- @@ -628,7 +648,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { fs.renameMu.RUnlock() if *ds == nil { return @@ -636,20 +656,20 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) { if len(**ds) != 0 { fs.renameMu.Lock() for _, d := range **ds { - d.checkCachingLocked() + d.checkCachingLocked(ctx) } fs.renameMu.Unlock() } putDentrySlice(*ds) } -func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) { +func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { if *ds == nil { fs.renameMu.Unlock() return } for _, d := range **ds { - d.checkCachingLocked() + d.checkCachingLocked(ctx) } fs.renameMu.Unlock() putDentrySlice(*ds) @@ -659,7 +679,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) { func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -671,7 +691,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -692,7 +712,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { // Get updated metadata for start as required by @@ -711,19 +731,40 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error { if rp.Mount() != vd.Mount() { return syserror.EXDEV } - // 9P2000.L supports hard links, but we don't. - return syserror.EPERM + d := vd.Dentry().Impl().(*dentry) + if d.isDir() { + return syserror.EPERM + } + gid := auth.KGID(atomic.LoadUint32(&d.gid)) + uid := auth.KUID(atomic.LoadUint32(&d.uid)) + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil { + return err + } + if d.nlink == 0 { + return syserror.ENOENT + } + if d.nlink == math.MaxUint32 { + return syserror.EMLINK + } + if err := parent.file.link(ctx, d.file, childName); err != nil { + return err + } + + // Success! + atomic.AddUint32(&d.nlink, 1) + return nil }, nil) } // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil { if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err @@ -758,34 +799,49 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { creds := rp.Credentials() _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - // If the gofer does not allow creating a socket or pipe, create a - // synthetic one, i.e. one that is kept entirely in memory. - if err == syserror.EPERM { - switch opts.Mode.FileType() { - case linux.S_IFSOCK: - parent.createSyntheticChildLocked(&createSyntheticOpts{ - name: name, - mode: opts.Mode, - kuid: creds.EffectiveKUID, - kgid: creds.EffectiveKGID, - endpoint: opts.Endpoint, - }) - return nil - case linux.S_IFIFO: - parent.createSyntheticChildLocked(&createSyntheticOpts{ - name: name, - mode: opts.Mode, - kuid: creds.EffectiveKUID, - kgid: creds.EffectiveKGID, - pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), - }) - return nil - } + if err != syserror.EPERM { + return err } - return err + + // EPERM means that gofer does not allow creating a socket or pipe. Fallback + // to creating a synthetic one, i.e. one that is kept entirely in memory. + + // Check that we're not overriding an existing file with a synthetic one. + _, err = fs.stepLocked(ctx, rp, parent, true, ds) + switch { + case err == nil: + // Step succeeded, another file exists. + return syserror.EEXIST + case err != syserror.ENOENT: + // Unexpected error. + return err + } + + switch opts.Mode.FileType() { + case linux.S_IFSOCK: + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + endpoint: opts.Endpoint, + }) + return nil + case linux.S_IFIFO: + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + }) + return nil + } + // Retain error from gofer if synthetic file cannot be created internally. + return syserror.EPERM }, nil) } @@ -803,7 +859,14 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + unlocked := false + unlock := func() { + if !unlocked { + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + unlocked = true + } + } + defer unlock() start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { @@ -813,7 +876,17 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf } } if rp.Done() { - return start.openLocked(ctx, rp, &opts) + // Reject attempts to open mount root directory with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } + if mustCreate { + return nil, syserror.EEXIST + } + start.IncRef() + defer start.DecRef(ctx) + unlock() + return start.open(ctx, rp, &opts) } afterTrailingSymlink: @@ -825,6 +898,10 @@ afterTrailingSymlink: if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } + // Reject attempts to open directories with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } // Determine whether or not we need to create a file. parent.dirMu.Lock() child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) @@ -844,9 +921,6 @@ afterTrailingSymlink: if mustCreate { return nil, syserror.EEXIST } - if !child.isDir() && rp.MustBeDir() { - return nil, syserror.ENOTDIR - } // Open existing child or follow symlink. if child.isSymlink() && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx, rp.Mount()) @@ -859,11 +933,18 @@ afterTrailingSymlink: start = parent goto afterTrailingSymlink } - return child.openLocked(ctx, rp, &opts) + if rp.MustBeDir() && !child.isDir() { + return nil, syserror.ENOTDIR + } + child.IncRef() + defer child.DecRef(ctx) + unlock() + return child.open(ctx, rp, &opts) } -// Preconditions: fs.renameMu must be locked. -func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +// Preconditions: The caller must hold no locks (since opening pipes may block +// indefinitely). +func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if err := d.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err @@ -926,7 +1007,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, syserror.ENXIO } if d.fs.iopts.OpenSocketsByConnecting { - return d.connectSocketLocked(ctx, opts) + return d.openSocketByConnecting(ctx, opts) } case linux.S_IFIFO: if d.isSynthetic() { @@ -935,7 +1016,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf } if vfd == nil { - if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil { + if vfd, err = d.openSpecialFile(ctx, mnt, opts); err != nil { return nil, err } } @@ -945,7 +1026,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // step is required even if !d.cachedMetadataAuthoritative() because // d.mappings has to be updated. // d.metadataMu has already been acquired if trunc == true. - d.updateFileSizeLocked(0) + d.updateSizeLocked(0) if d.cachedMetadataAuthoritative() { d.touchCMtimeLocked() @@ -954,7 +1035,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return vfd, err } -func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { if opts.Flags&linux.O_DIRECT != 0 { return nil, syserror.EINVAL } @@ -974,7 +1055,7 @@ func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) return fd, nil } -func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if opts.Flags&linux.O_DIRECT != 0 { return nil, syserror.EINVAL @@ -1016,8 +1097,10 @@ retry: return &fd.vfsfd, nil } -// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. -// !d.isSynthetic(). +// Preconditions: +// * d.fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !d.isSynthetic(). func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return nil, err @@ -1040,10 +1123,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } creds := rp.Credentials() name := rp.Component() - // Filter file creation flags and O_LARGEFILE out; the create RPC already - // has the semantics of O_CREAT|O_EXCL, while some servers will choke on - // O_LARGEFILE. - createFlags := p9.OpenFlags(opts.Flags &^ (vfs.FileCreationFlags | linux.O_LARGEFILE)) + // We only want the access mode for creating the file. + createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) if err != nil { dirfile.close(ctx) @@ -1076,12 +1157,18 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { child.handleMu.Lock() - child.handle.file = openFile - if fdobj != nil { - child.handle.fd = int32(fdobj.Release()) + if vfs.MayReadFileWithOpenFlags(opts.Flags) { + child.readFile = openFile + if fdobj != nil { + child.hostFD = int32(fdobj.Release()) + } + } else if fdobj != nil { + // Can't use fdobj if it's not readable. + fdobj.Close() + } + if vfs.MayWriteFileWithOpenFlags(opts.Flags) { + child.writeFile = openFile } - child.handleReadable = vfs.MayReadFileWithOpenFlags(opts.Flags) - child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags) child.handleMu.Unlock() } // Insert the dentry into the tree. @@ -1117,7 +1204,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } - d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) + d.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) return childVFSFD, nil } @@ -1125,7 +1212,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -1145,7 +1232,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa var ds *[]*dentry fs.renameMu.Lock() - defer fs.renameMuUnlockAndCheckCaching(&ds) + defer fs.renameMuUnlockAndCheckCaching(ctx, &ds) newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds) if err != nil { return err @@ -1224,6 +1311,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if !renamed.isDir() { return syserror.EISDIR } + if genericIsAncestorDentry(replaced, renamed) { + return syserror.ENOTEMPTY + } } else { if rp.MustBeDir() || renamed.isDir() { return syserror.ENOTDIR @@ -1235,7 +1325,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return nil } mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil { return err } @@ -1260,7 +1350,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } // Update the dentry tree. - vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD) + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) if replaced != nil { replaced.setDeleted() if replaced.isSynthetic() { @@ -1274,14 +1364,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // with reference counts and queue oldParent for checkCachingLocked if the // parent isn't actually changing. if oldParent != newParent { + oldParent.decRefLocked() ds = appendDentry(ds, oldParent) newParent.IncRef() if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ } + renamed.parent = newParent } - renamed.parent = newParent renamed.name = newName if newParent.children == nil { newParent.children = make(map[string]*dentry) @@ -1322,17 +1413,17 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()); err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + err = d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + if err != nil { return err } - fs.renameMuRUnlockAndCheckCaching(&ds) if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - d.InotifyWithParent(ev, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } @@ -1341,7 +1432,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statx{}, err @@ -1358,7 +1449,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statfs{}, err @@ -1392,7 +1483,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { creds := rp.Credentials() _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) return err @@ -1404,11 +1495,11 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return fs.unlinkAt(ctx, rp, false /* dir */) } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -1425,70 +1516,72 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath path: opts.Addr, }, nil } - return d.endpoint, nil + if d.endpoint != nil { + return d.endpoint, nil + } } return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err } - return d.listxattr(ctx, rp.Credentials(), size) + return d.listXattr(ctx, rp.Credentials(), size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err } - return d.getxattr(ctx, rp.Credentials(), &opts) + return d.getXattr(ctx, rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + err = d.setXattr(ctx, rp.Credentials(), &opts) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + if err != nil { return err } - fs.renameMuRUnlockAndCheckCaching(&ds) - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.removexattr(ctx, rp.Credentials(), name); err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + err = d.removeXattr(ctx, rp.Credentials(), name) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + if err != nil { return err } - fs.renameMuRUnlockAndCheckCaching(&ds) - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 2b83094cd..8608471f8 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -62,9 +62,13 @@ import ( const Name = "9p" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -77,7 +81,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client + client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -95,7 +99,7 @@ type filesystem struct { // reference count (such that it is usable as vfs.ResolvingPath.Start() or // is reachable from its children), or if it is a child dentry (such that // it is reachable from its parent). - renameMu sync.RWMutex + renameMu sync.RWMutex `state:"nosave"` // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) @@ -107,7 +111,7 @@ type filesystem struct { // syncableDentries contains all dentries in this filesystem for which // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. // These fields are protected by syncMu. - syncMu sync.Mutex + syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} @@ -120,6 +124,8 @@ type filesystem struct { // dentries, it comes from QID.Path from the 9P server. Synthetic dentries // have have their inodeNumber generated sequentially, with the MSB reserved to // prevent conflicts with regular dentries. +// +// +stateify savable type inodeNumber uint64 // Reserve MSB for synthetic mounts. @@ -132,6 +138,7 @@ func inoFromPath(path uint64) inodeNumber { return inodeNumber(path &^ syntheticInoMask) } +// +stateify savable type filesystemOptions struct { // "Standard" 9P options. fd int @@ -177,6 +184,8 @@ type filesystemOptions struct { // InteropMode controls the client's interaction with other remote filesystem // users. +// +// +stateify savable type InteropMode uint32 const ( @@ -192,10 +201,10 @@ const ( // // - File timestamps are based on client clocks. This ensures that users of // the client observe timestamps that are coherent with their own clocks - // and consistent with Linux's semantics. However, since it is not always - // possible for clients to set arbitrary atimes and mtimes, and never - // possible for clients to set arbitrary ctimes, file timestamp changes are - // stored in the client only and never sent to the remote filesystem. + // and consistent with Linux's semantics (in particular, it is not always + // possible for clients to set arbitrary atimes and mtimes depending on the + // remote filesystem implementation, and never possible for clients to set + // arbitrary ctimes.) InteropModeExclusive InteropMode = iota // InteropModeWritethrough is appropriate when there are read-only users of @@ -235,6 +244,8 @@ const ( // InternalFilesystemOptions may be passed as // vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem. +// +// +stateify savable type InternalFilesystemOptions struct { // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. This is necessary for deployments in @@ -482,7 +493,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { attachFile.close(ctx) - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the @@ -495,17 +506,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { - ctx := context.Background() +func (fs *filesystem) Release(ctx context.Context) { mf := fs.mfp.MemoryFile() fs.syncMu.Lock() for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() - if d.handleWritable { + if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -515,9 +525,9 @@ func (fs *filesystem) Release() { d.dirty.RemoveAll() d.dataMu.Unlock() // Close the host fd if one exists. - if d.handle.fd >= 0 { - syscall.Close(int(d.handle.fd)) - d.handle.fd = -1 + if d.hostFD >= 0 { + syscall.Close(int(d.hostFD)) + d.hostFD = -1 } d.handleMu.Unlock() } @@ -535,6 +545,8 @@ func (fs *filesystem) Release() { } // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -559,14 +571,12 @@ type dentry struct { // filesystem.renameMu. name string - // We don't support hard links, so each dentry maps 1:1 to an inode. - // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file + file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -578,7 +588,7 @@ type dentry struct { cached bool dentryEntry - dirMu sync.Mutex + dirMu sync.Mutex `state:"nosave"` // If this dentry represents a directory, children contains: // @@ -602,9 +612,15 @@ type dentry struct { // returned by the server. dirents is protected by dirMu. dirents []vfs.Dirent - // Cached metadata; protected by metadataMu and accessed using atomic - // memory operations unless otherwise specified. - metadataMu sync.Mutex + // Cached metadata; protected by metadataMu. + // To access: + // - In situations where consistency is not required (like stat), these + // can be accessed using atomic operations only (without locking). + // - Lock metadataMu and can access without atomic operations. + // To mutate: + // - Lock metadataMu and use atomic operations to update because we might + // have atomic readers that don't hold the lock. + metadataMu sync.Mutex `state:"nosave"` ino inodeNumber // immutable mode uint32 // type is immutable, perms are mutable uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic @@ -615,47 +631,56 @@ type dentry struct { mtime int64 ctime int64 btime int64 - // File size, protected by both metadataMu and dataMu (i.e. both must be - // locked to mutate it). + // File size, which differs from other metadata in two ways: + // + // - We make a best-effort attempt to keep it up to date even if + // !dentry.cachedMetadataAuthoritative() for the sake of O_APPEND writes. + // + // - size is protected by both metadataMu and dataMu (i.e. both must be + // locked to mutate it; locking either is sufficient to access it). size uint64 + // If this dentry does not represent a synthetic file, deleted is 0, and + // atimeDirty/mtimeDirty are non-zero, atime/mtime may have diverged from the + // remote file's timestamps, which should be updated when this dentry is + // evicted. + atimeDirty uint32 + mtimeDirty uint32 // nlink counts the number of hard links to this dentry. It's updated and // accessed using atomic operations. It's not protected by metadataMu like the // other metadata fields. nlink uint32 - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` // If this dentry represents a regular file, mappings tracks mappings of // the file into memmap.MappingSpaces. mappings is protected by mapsMu. mappings memmap.MappingSet - // If this dentry represents a regular file or directory: - // - // - handle is the I/O handle used by all regularFileFDs/directoryFDs - // representing this dentry. - // - // - handleReadable is true if handle is readable. - // - // - handleWritable is true if handle is writable. + // - If this dentry represents a regular file or directory, readFile is the + // p9.File used for reads by all regularFileFDs/directoryFDs representing + // this dentry. // - // Invariants: + // - If this dentry represents a regular file, writeFile is the p9.File + // used for writes by all regularFileFDs representing this dentry. // - // - If handleReadable == handleWritable == false, then handle.file == nil - // (i.e. there is no open handle). Conversely, if handleReadable || - // handleWritable == true, then handle.file != nil (i.e. there is an open - // handle). - // - // - handleReadable and handleWritable cannot transition from true to false - // (i.e. handles may not be downgraded). + // - If this dentry represents a regular file, hostFD is the host FD used + // for memory mappings and I/O (when applicable) in preference to readFile + // and writeFile. hostFD is always readable; if !writeFile.isNil(), it must + // also be writable. If hostFD is -1, no such host FD is available. // // These fields are protected by handleMu. - handleMu sync.RWMutex - handle handle - handleReadable bool - handleWritable bool + // + // readFile and writeFile may or may not represent the same p9.File. Once + // either p9.File transitions from closed (isNil() == true) to open + // (isNil() == false), it may be mutated with handleMu locked, but cannot + // be closed until the dentry is destroyed. + handleMu sync.RWMutex `state:"nosave"` + readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + hostFD int32 - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` // If this dentry represents a regular file that is client-cached, cache // maps offsets into the cached file to offsets into @@ -667,7 +692,7 @@ type dentry struct { // tracks dirty segments in cache. dirty is protected by dataMu. dirty fsutil.DirtySet - // pf implements platform.File for mappings of handle.fd. + // pf implements platform.File for mappings of hostFD. pf dentryPlatformFile // If this dentry represents a symbolic link, InteropModeShared is not in @@ -687,6 +712,13 @@ type dentry struct { locks vfs.FileLocks // Inotify watches for this dentry. + // + // Note that inotify may behave unexpectedly in the presence of hard links, + // because dentries corresponding to the same file have separate inotify + // watches when they should share the same set. This is the case because it is + // impossible for us to know for sure whether two dentries correspond to the + // same underlying file (see the gofer filesystem section fo vfs/inotify.md for + // a more in-depth discussion on this matter). watches vfs.Watches } @@ -729,9 +761,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, - handle: handle{ - fd: -1, - }, + hostFD: -1, } d.pf.dentry = d if mask.UID { @@ -779,8 +809,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool { // updateFromP9Attrs is called to update d's metadata after an update from the // remote filesystem. -func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { - d.metadataMu.Lock() +// Precondition: d.metadataMu must be locked. +func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { if mask.Mode { if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want { d.metadataMu.Unlock() @@ -798,10 +828,12 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { if attr.BlockSize != 0 { atomic.StoreUint32(&d.blockSize, uint32(attr.BlockSize)) } - if mask.ATime { + // Don't override newer client-defined timestamps with old server-defined + // ones. + if mask.ATime && atomic.LoadUint32(&d.atimeDirty) == 0 { atomic.StoreInt64(&d.atime, dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds)) } - if mask.MTime { + if mask.MTime && atomic.LoadUint32(&d.mtimeDirty) == 0 { atomic.StoreInt64(&d.mtime, dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds)) } if mask.CTime { @@ -814,23 +846,33 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { atomic.StoreUint32(&d.nlink, uint32(attr.NLink)) } if mask.Size { - d.updateFileSizeLocked(attr.Size) + d.updateSizeLocked(attr.Size) } - d.metadataMu.Unlock() } // Preconditions: !d.isSynthetic() func (d *dentry) updateFromGetattr(ctx context.Context) error { - // Use d.handle.file, which represents a 9P fid that has been opened, in - // preference to d.file, which represents a 9P fid that has not. This may - // be significantly more efficient in some implementations. + // Use d.readFile or d.writeFile, which represent 9P fids that have been + // opened, in preference to d.file, which represents a 9P fid that has not. + // This may be significantly more efficient in some implementations. Prefer + // d.writeFile over d.readFile since some filesystem implementations may + // update a writable handle's metadata after writes to that handle, without + // making metadata updates immediately visible to read-only handles + // representing the same file. var ( file p9file handleMuRLocked bool ) + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() d.handleMu.RLock() - if !d.handle.file.isNil() { - file = d.handle.file + if !d.writeFile.isNil() { + file = d.writeFile + handleMuRLocked = true + } else if !d.readFile.isNil() { + file = d.readFile handleMuRLocked = true } else { file = d.file @@ -843,7 +885,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { if err != nil { return err } - d.updateFromP9Attrs(attrMask, &attr) + d.updateFromP9AttrsLocked(attrMask, &attr) return nil } @@ -879,7 +921,8 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.DevMinor = d.fs.devMinor } -func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error { +func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -887,45 +930,49 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } if err := mnt.CheckBeginWrite(); err != nil { return err } defer mnt.EndWrite() - setLocalAtime := false - setLocalMtime := false + + if stat.Mask&linux.STATX_SIZE != 0 { + // Reject attempts to truncate files other than regular files, since + // filesystem implementations may return the wrong errno. + switch mode.FileType() { + case linux.S_IFREG: + // ok + case linux.S_IFDIR: + return syserror.EISDIR + default: + return syserror.EINVAL + } + } + + var now int64 if d.cachedMetadataAuthoritative() { - // Timestamp updates will be handled locally. - setLocalAtime = stat.Mask&linux.STATX_ATIME != 0 - setLocalMtime = stat.Mask&linux.STATX_MTIME != 0 - stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME - - // Prepare for truncate. - if stat.Mask&linux.STATX_SIZE != 0 { - switch d.mode & linux.S_IFMT { - case linux.S_IFREG: - if !setLocalMtime { - // Truncate updates mtime. - setLocalMtime = true - stat.Mtime.Nsec = linux.UTIME_NOW - } - case linux.S_IFDIR: - return syserror.EISDIR - default: - return syserror.EINVAL + // Truncate updates mtime. + if stat.Mask&(linux.STATX_SIZE|linux.STATX_MTIME) == linux.STATX_SIZE { + stat.Mask |= linux.STATX_MTIME + stat.Mtime = linux.StatxTimestamp{ + Nsec: linux.UTIME_NOW, } } + + // Use client clocks for timestamps. + now = d.fs.clock.Now().Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = statxTimestampFromDentry(now) + } + if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = statxTimestampFromDentry(now) + } } + d.metadataMu.Lock() defer d.metadataMu.Unlock() - if stat.Mask&linux.STATX_SIZE != 0 { - // The size needs to be changed even when - // !d.cachedMetadataAuthoritative() because d.mappings has to be - // updated. - d.updateFileSizeLocked(stat.Size) - } if !d.isSynthetic() { if stat.Mask != 0 { if err := d.file.setAttr(ctx, p9.SetAttrMask{ @@ -949,6 +996,12 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin }); err != nil { return err } + if stat.Mask&linux.STATX_SIZE != 0 { + // d.size should be kept up to date, and privatized + // copy-on-write mappings of truncated pages need to be + // invalidated, even if InteropModeShared is in effect. + d.updateSizeLocked(stat.Size) + } } if d.fs.opts.interop == InteropModeShared { // There's no point to updating d's metadata in this case since @@ -958,7 +1011,6 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin return nil } } - now := d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_MODE != 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } @@ -968,33 +1020,51 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin if stat.Mask&linux.STATX_GID != 0 { atomic.StoreUint32(&d.gid, stat.GID) } - if setLocalAtime { - if stat.Atime.Nsec == linux.UTIME_NOW { - atomic.StoreInt64(&d.atime, now) - } else { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) - } - // Restore mask bits that we cleared earlier. - stat.Mask |= linux.STATX_ATIME + // Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because + // if d.cachedMetadataAuthoritative() then we converted stat.Atime and + // stat.Mtime to client-local timestamps above, and if + // !d.cachedMetadataAuthoritative() then we returned after calling + // d.file.setAttr(). For the same reason, now must have been initialized. + if stat.Mask&linux.STATX_ATIME != 0 { + atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreUint32(&d.atimeDirty, 0) } - if setLocalMtime { - if stat.Mtime.Nsec == linux.UTIME_NOW { - atomic.StoreInt64(&d.mtime, now) - } else { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) - } - // Restore mask bits that we cleared earlier. - stat.Mask |= linux.STATX_MTIME + if stat.Mask&linux.STATX_MTIME != 0 { + atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) return nil } +// doAllocate performs an allocate operation on d. Note that d.metadataMu will +// be held when allocate is called. +func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate func() error) error { + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Allocating a smaller size is a noop. + size := offset + length + if d.cachedMetadataAuthoritative() && size <= d.size { + return nil + } + + err := allocate() + if err != nil { + return err + } + d.updateSizeLocked(size) + if d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + return nil +} + // Preconditions: d.metadataMu must be locked. -func (d *dentry) updateFileSizeLocked(newSize uint64) { +func (d *dentry) updateSizeLocked(newSize uint64) { d.dataMu.Lock() oldSize := d.size - d.size = newSize + atomic.StoreUint64(&d.size, newSize) // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings // below. This allows concurrent calls to Read/Translate/etc. These // functions synchronize with truncation by refusing to use cache @@ -1029,6 +1099,21 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + // We only support xattrs prefixed with "user." (see b/148380782). Currently, + // there is no need to expose any other xattrs through a gofer. + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + kuid := auth.KUID(atomic.LoadUint32(&d.uid)) + kgid := auth.KGID(atomic.LoadUint32(&d.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err + } + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) +} + func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid))) } @@ -1068,10 +1153,10 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { +func (d *dentry) DecRef(ctx context.Context) { if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { d.fs.renameMu.Lock() - d.checkCachingLocked() + d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() } else if refs < 0 { panic("gofer.dentry.DecRef() called without holding a reference") @@ -1088,7 +1173,7 @@ func (d *dentry) decRefLocked() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { if d.isDir() { events |= linux.IN_ISDIR } @@ -1096,9 +1181,9 @@ func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { d.fs.renameMu.RLock() // The ordering below is important, Linux always notifies the parent first. if d.parent != nil { - d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted()) + d.parent.watches.Notify(ctx, d.name, events, cookie, et, d.isDeleted()) } - d.watches.Notify("", events, cookie, et, d.isDeleted()) + d.watches.Notify(ctx, "", events, cookie, et, d.isDeleted()) d.fs.renameMu.RUnlock() } @@ -1110,10 +1195,10 @@ func (d *dentry) Watches() *vfs.Watches { // OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. // // If no watches are left on this dentry and it has no references, cache it. -func (d *dentry) OnZeroWatches() { +func (d *dentry) OnZeroWatches(ctx context.Context) { if atomic.LoadInt64(&d.refs) == 0 { d.fs.renameMu.Lock() - d.checkCachingLocked() + d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() } } @@ -1127,8 +1212,9 @@ func (d *dentry) OnZeroWatches() { // operation. One of the calls may destroy the dentry, so subsequent calls will // do nothing. // -// Preconditions: d.fs.renameMu must be locked for writing. -func (d *dentry) checkCachingLocked() { +// Preconditions: d.fs.renameMu must be locked for writing; it may be +// temporarily unlocked. +func (d *dentry) checkCachingLocked(ctx context.Context) { // Dentries with a non-zero reference count must be retained. (The only way // to obtain a reference on a dentry with zero references is via path // resolution, which requires renameMu, so if d.refs is zero then it will @@ -1150,14 +1236,14 @@ func (d *dentry) checkCachingLocked() { // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { if d.isDeleted() { - d.watches.HandleDeletion() + d.watches.HandleDeletion(ctx) } if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- d.cached = false } - d.destroyLocked() + d.destroyLocked(ctx) return } // If d still has inotify watches and it is not deleted or invalidated, we @@ -1192,7 +1278,7 @@ func (d *dentry) checkCachingLocked() { if !victim.vfsd.IsDead() { // Note that victim can't be a mount point (in any mount // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(&victim.vfsd) + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) delete(victim.parent.children, victim.name) // We're only deleting the dentry, not the file it // represents, so we don't need to update @@ -1200,19 +1286,21 @@ func (d *dentry) checkCachingLocked() { } victim.parent.dirMu.Unlock() } - victim.destroyLocked() + victim.destroyLocked(ctx) } // Whether or not victim was destroyed, we brought fs.cachedDentriesLen // back down to fs.opts.maxCachedDentries, so we don't loop. } } -// destroyLocked destroys the dentry. It may flushes dirty pages from cache, -// close p9 file and remove reference on parent dentry. +// destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is -// not a child dentry. -func (d *dentry) destroyLocked() { +// Preconditions: +// * d.fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * d.refs == 0. +// * d.parent.children[d.name] != d, i.e. d is not reachable by path traversal +// from its former parent dentry. +func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: // Mark the dentry destroyed. @@ -1223,39 +1311,68 @@ func (d *dentry) destroyLocked() { panic("dentry.destroyLocked() called with references on the dentry") } - ctx := context.Background() + // Allow the following to proceed without renameMu locked to improve + // scalability. + d.fs.renameMu.Unlock() + + mf := d.fs.mfp.MemoryFile() d.handleMu.Lock() - if !d.handle.file.isNil() { - mf := d.fs.mfp.MemoryFile() - d.dataMu.Lock() + d.dataMu.Lock() + if h := d.writeHandleLocked(); h.isOpen() { // Write dirty pages back to the remote filesystem. - if d.handleWritable { - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { - log.Warningf("gofer.dentry.DecRef: failed to write dirty data back: %v", err) - } + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to write dirty data back: %v", err) } - // Discard cached data. + } + // Discard cached data. + if !d.cache.IsEmpty() { + mf.MarkAllUnevictable(d) d.cache.DropAll(mf) d.dirty.RemoveAll() - d.dataMu.Unlock() - // Clunk open fids and close open host FDs. - d.handle.close(ctx) + } + d.dataMu.Unlock() + // Clunk open fids and close open host FDs. + if !d.readFile.isNil() { + d.readFile.close(ctx) + } + if !d.writeFile.isNil() && d.readFile != d.writeFile { + d.writeFile.close(ctx) + } + d.readFile = p9file{} + d.writeFile = p9file{} + if d.hostFD >= 0 { + syscall.Close(int(d.hostFD)) + d.hostFD = -1 } d.handleMu.Unlock() if !d.file.isNil() { - d.file.close(ctx) + // Note that it's possible that d.atimeDirty or d.mtimeDirty are true, + // i.e. client and server timestamps may differ (because e.g. a client + // write was serviced by the page cache, and only written back to the + // remote file later). Ideally, we'd write client timestamps back to + // the remote filesystem so that timestamps for a new dentry + // instantiated for the same file would remain coherent. Unfortunately, + // this turns out to be too expensive in many cases, so for now we + // don't do this. + if err := d.file.close(ctx); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) + } d.file = p9file{} + // Remove d from the set of syncable dentries. d.fs.syncMu.Lock() delete(d.fs.syncableDentries, d) d.fs.syncMu.Unlock() } + + d.fs.renameMu.Lock() + // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. if d.parent != nil { if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkCachingLocked() + d.parent.checkCachingLocked(ctx) } else if refs < 0 { panic("gofer.dentry.DecRef() called without holding a reference") } @@ -1270,9 +1387,7 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -// We only support xattrs prefixed with "user." (see b/148380782). Currently, -// there is no need to expose any other xattrs through a gofer. -func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { +func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { if d.file.isNil() || !d.userXattrSupported() { return nil, nil } @@ -1282,6 +1397,7 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui } xattrs := make([]string, 0, len(xattrMap)) for x := range xattrMap { + // We only support xattrs in the user.* namespace. if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { xattrs = append(xattrs, x) } @@ -1289,141 +1405,166 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui return xattrs, nil } -func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { +func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { if d.file.isNil() { return "", syserror.ENODATA } - if err := d.checkPermissions(creds, vfs.MayRead); err != nil { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return "", syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return "", syserror.ENODATA - } return d.file.getXattr(ctx, opts.Name, opts.Size) } -func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error { +func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { if d.file.isNil() { return syserror.EPERM } - if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return syserror.EPERM - } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } -func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error { +func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error { if d.file.isNil() { return syserror.EPERM } - if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return syserror.EPERM - } return d.file.removeXattr(ctx, name) } // Extended attributes in the user.* namespace are only supported for regular // files and directories. func (d *dentry) userXattrSupported() bool { - filetype := linux.S_IFMT & atomic.LoadUint32(&d.mode) - return filetype == linux.S_IFREG || filetype == linux.S_IFDIR + filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType() + return filetype == linux.ModeRegular || filetype == linux.ModeDirectory } -// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir(). +// Preconditions: +// * !d.isSynthetic(). +// * d.isRegularFile() || d.isDir(). func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error { // O_TRUNC unconditionally requires us to obtain a new handle (opened with // O_TRUNC). if !trunc { d.handleMu.RLock() - if (!read || d.handleReadable) && (!write || d.handleWritable) { - // The current handle is sufficient. + if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) { + // Current handles are sufficient. d.handleMu.RUnlock() return nil } d.handleMu.RUnlock() } - haveOldFD := false + fdToClose := int32(-1) + invalidateTranslations := false d.handleMu.Lock() - if (read && !d.handleReadable) || (write && !d.handleWritable) || trunc { - // Get a new handle. - wantReadable := d.handleReadable || read - wantWritable := d.handleWritable || write - h, err := openHandle(ctx, d.file, wantReadable, wantWritable, trunc) + if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc { + // Get a new handle. If this file has been opened for both reading and + // writing, try to get a single handle that is usable for both: + // + // - Writable memory mappings of a host FD require that the host FD is + // opened for both reading and writing. + // + // - NOTE(b/141991141): Some filesystems may not ensure coherence + // between multiple handles for the same file. + openReadable := !d.readFile.isNil() || read + openWritable := !d.writeFile.isNil() || write + h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc) + if err == syserror.EACCES && (openReadable != read || openWritable != write) { + // It may not be possible to use a single handle for both + // reading and writing, since permissions on the file may have + // changed to e.g. disallow reading after previously being + // opened for reading. In this case, we have no choice but to + // use separate handles for reading and writing. + ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d) + openReadable = read + openWritable = write + h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + } if err != nil { d.handleMu.Unlock() return err } - if !d.handle.file.isNil() { - // Check that old and new handles are compatible: If the old handle - // includes a host file descriptor but the new one does not, or - // vice versa, old and new memory mappings may be incoherent. - haveOldFD = d.handle.fd >= 0 - haveNewFD := h.fd >= 0 - if haveOldFD != haveNewFD { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: can't change host FD availability from %v to %v across dentry handle upgrade", haveOldFD, haveNewFD) - h.close(ctx) - return syserror.EIO - } - if haveOldFD { - // We may have raced with callers of d.pf.FD() that are now - // using the old file descriptor, preventing us from safely - // closing it. We could handle this by invalidating existing - // memmap.Translations, but this is expensive. Instead, use - // dup3 to make the old file descriptor refer to the new file - // description, then close the new file descriptor (which is no - // longer needed). Racing callers may use the old or new file - // description, but this doesn't matter since they refer to the - // same file (unless d.fs.opts.overlayfsStaleRead is true, - // which we handle separately). - if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil { + + if d.hostFD < 0 && h.fd >= 0 && openReadable && (d.writeFile.isNil() || openWritable) { + // We have no existing FD, and the new FD meets the requirements + // for d.hostFD, so start using it. + d.hostFD = h.fd + } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable { + // We have an existing read-only FD, but the file has just been + // opened for writing, so we need to start supporting writable memory + // mappings. This may race with callers of d.pf.FD() using the existing + // FD, so in most cases we need to delay closing the old FD until after + // invalidating memmap.Translations that might have observed it. + if !openReadable || h.fd < 0 { + // We don't have a read/write FD, so we have no FD that can be + // used to create writable memory mappings. Switch to using the + // internal page cache. + invalidateTranslations = true + fdToClose = d.hostFD + d.hostFD = -1 + } else if d.fs.opts.overlayfsStaleRead { + // We do have a read/write FD, but it may not be coherent with + // the existing read-only FD, so we must switch to mappings of + // the new FD in both the application and sentry. + if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) h.close(ctx) return err } - syscall.Close(int(h.fd)) - h.fd = d.handle.fd - if d.fs.opts.overlayfsStaleRead { - // Replace sentry mappings of the old FD with mappings of - // the new FD, since the two are not necessarily coherent. - if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) - h.close(ctx) - return err - } + invalidateTranslations = true + fdToClose = d.hostFD + d.hostFD = h.fd + } else { + // We do have a read/write FD. To avoid invalidating existing + // memmap.Translations (which is expensive), use dup3 to make + // the old file descriptor refer to the new file description, + // then close the new file descriptor (which is no longer + // needed). Racing callers of d.pf.FD() may use the old or new + // file description, but this doesn't matter since they refer + // to the same file, and any racing mappings must be read-only. + if err := syscall.Dup3(int(h.fd), int(d.hostFD), syscall.O_CLOEXEC); err != nil { + oldHostFD := d.hostFD + d.handleMu.Unlock() + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldHostFD, err) + h.close(ctx) + return err } - // Clunk the old fid before making the new handle visible (by - // unlocking d.handleMu). - d.handle.file.close(ctx) + fdToClose = h.fd } + } else { + // h.fd is not useful. + fdToClose = h.fd + } + + // Switch to new fids. + var oldReadFile p9file + if openReadable { + oldReadFile = d.readFile + d.readFile = h.file + } + var oldWriteFile p9file + if openWritable { + oldWriteFile = d.writeFile + d.writeFile = h.file + } + // NOTE(b/141991141): Clunk old fids before making new fids visible (by + // unlocking d.handleMu). + if !oldReadFile.isNil() { + oldReadFile.close(ctx) + } + if !oldWriteFile.isNil() && oldReadFile != oldWriteFile { + oldWriteFile.close(ctx) } - // Switch to the new handle. - d.handle = h - d.handleReadable = wantReadable - d.handleWritable = wantWritable } d.handleMu.Unlock() - if d.fs.opts.overlayfsStaleRead && haveOldFD { - // Invalidate application mappings that may be using the old FD; they + if invalidateTranslations { + // Invalidate application mappings that may be using an old FD; they // will be replaced with mappings using the new FD after future calls // to d.Translate(). This requires holding d.mapsMu, which precedes // d.handleMu in the lock order. @@ -1431,10 +1572,54 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool d.mappings.InvalidateAll(memmap.InvalidateOpts{}) d.mapsMu.Unlock() } + if fdToClose >= 0 { + syscall.Close(int(fdToClose)) + } return nil } +// Preconditions: d.handleMu must be locked. +func (d *dentry) readHandleLocked() handle { + return handle{ + file: d.readFile, + fd: d.hostFD, + } +} + +// Preconditions: d.handleMu must be locked. +func (d *dentry) writeHandleLocked() handle { + return handle{ + file: d.writeFile, + fd: d.hostFD, + } +} + +func (d *dentry) syncRemoteFile(ctx context.Context) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + return d.syncRemoteFileLocked(ctx) +} + +// Preconditions: d.handleMu must be locked. +func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if d.hostFD >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(d.hostFD)) + ctx.UninterruptibleSleepFinish(false) + return err + } + if !d.writeFile.isNil() { + return d.writeFile.fsync(ctx) + } + if !d.readFile.isNil() { + return d.readFile.fsync(ctx) + } + return nil +} + // incLinks increments link count. func (d *dentry) incLinks() { if atomic.LoadUint32(&d.nlink) == 0 { @@ -1455,12 +1640,14 @@ func (d *dentry) decLinks() { // fileDescription is embedded by gofer implementations of // vfs.FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once + lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. } func (fd *fileDescription) filesystem() *filesystem { @@ -1489,42 +1676,42 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()); err != nil { + if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil { return err } if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent) + fd.dentry().InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { - return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size) +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size) } -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { - return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts) +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.dentry().getXattr(ctx, auth.CredentialsFromContext(ctx), &opts) } -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { d := fd.dentry() - if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { + if err := d.setXattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { return err } - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d := fd.dentry() - if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { + if err := d.removeXattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { return err } - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index adff39490..bfe75dfe4 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -20,10 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/contexttest" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" ) func TestDestroyIdempotent(t *testing.T) { + ctx := contexttest.Context(t) fs := filesystem{ + mfp: pgalloc.MemoryFileProviderFromContext(ctx), syncableDentries: make(map[*dentry]struct{}), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. @@ -31,7 +34,6 @@ func TestDestroyIdempotent(t *testing.T) { }, } - ctx := contexttest.Context(t) attr := &p9.Attr{ Mode: p9.ModeRegular, } @@ -50,7 +52,9 @@ func TestDestroyIdempotent(t *testing.T) { } parent.cacheNewChildLocked(child, "child") - child.checkCachingLocked() + fs.renameMu.Lock() + defer fs.renameMu.Unlock() + child.checkCachingLocked(ctx) if got := atomic.LoadInt64(&child.refs); got != -1 { t.Fatalf("child.refs=%d, want: -1", got) } @@ -58,6 +62,6 @@ func TestDestroyIdempotent(t *testing.T) { if got := atomic.LoadInt64(&parent.refs); got != -1 { t.Fatalf("parent.refs=%d, want: -1", got) } - child.checkCachingLocked() - child.checkCachingLocked() + child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 8792ca4f2..a9ebe1206 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -25,6 +25,8 @@ import ( // handle represents a remote "open file descriptor", consisting of an opened // fid (p9.File) and optionally a host file descriptor. +// +// These are explicitly not savable. type handle struct { file p9file fd int32 // -1 if unavailable @@ -63,6 +65,10 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand }, nil } +func (h *handle) isOpen() bool { + return !h.file.isNil() +} + func (h *handle) close(ctx context.Context) { h.file.close(ctx) h.file = p9file{} @@ -124,18 +130,3 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o } return cp, cperr } - -func (h *handle) sync(ctx context.Context) error { - // Handle most common case first. - if h.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(h.fd)) - ctx.UninterruptibleSleepFinish(false) - return err - } - if h.file.isNil() { - // File hasn't been touched, there is nothing to sync. - return nil - } - return h.file.fsync(ctx) -} diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 87f0b877f..21b4a96fe 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -127,6 +127,13 @@ func (f p9file) close(ctx context.Context) error { return err } +func (f p9file) setAttrClose(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error { + ctx.UninterruptibleSleepStart(false) + err := f.file.SetAttrClose(valid, attr) + ctx.UninterruptibleSleepFinish(false) + return err +} + func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { ctx.UninterruptibleSleepStart(false) fdobj, qid, iounit, err := f.file.Open(flags) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index a2f02d9c7..eeaf6e444 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -40,16 +39,17 @@ func (d *dentry) isRegularFile() bool { return d.fileType() == linux.S_IFREG } +// +stateify savable type regularFileFD struct { fileDescription // off is the file offset. off is protected by mu. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { +func (fd *regularFileFD) Release(context.Context) { } // OnClose implements vfs.FileDescriptionImpl.OnClose. @@ -57,43 +57,34 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { if !fd.vfsfd.IsWritable() { return nil } - // Skip flushing if writes may be buffered by the client, since (as with - // the VFS1 client) we don't flush buffered writes on close anyway. + // Skip flushing if there are client-buffered writes, since (as with the + // VFS1 client) we don't flush buffered writes on close anyway. d := fd.dentry() - if d.fs.opts.interop == InteropModeExclusive { + if d.fs.opts.interop != InteropModeExclusive { + return nil + } + d.dataMu.RLock() + haveDirtyPages := !d.dirty.IsEmpty() + d.dataMu.RUnlock() + if haveDirtyPages { return nil } d.handleMu.RLock() defer d.handleMu.RUnlock() - return d.handle.file.flush(ctx) + if d.writeFile.isNil() { + return nil + } + return d.writeFile.flush(ctx) } // Allocate implements vfs.FileDescriptionImpl.Allocate. func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { - d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - - size := offset + length - - // Allocating a smaller size is a noop. - if size <= d.size { - return nil - } - - d.handleMu.Lock() - defer d.handleMu.Unlock() - - err := d.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) - if err != nil { - return err - } - d.size = size - if !d.cachedMetadataAuthoritative() { - d.touchCMtimeLocked() - } - return nil + return d.doAllocate(ctx, offset, length, func() error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) } // PRead implements vfs.FileDescriptionImpl.PRead. @@ -112,10 +103,14 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // Check for reading at EOF before calling into MM (but not under // InteropModeShared, which makes d.size unreliable). d := fd.dentry() - if d.fs.opts.interop != InteropModeShared && uint64(offset) >= atomic.LoadUint64(&d.size) { + if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) { return 0, io.EOF } + var ( + n int64 + readErr error + ) if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { // Lock d.metadataMu for the rest of the read to prevent d.size from // changing. @@ -126,20 +121,25 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs if err := d.writeback(ctx, offset, dst.NumBytes()); err != nil { return 0, err } - } - - rw := getDentryReadWriter(ctx, d, offset) - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { + rw := getDentryReadWriter(ctx, d, offset) // Require the read to go to the remote file. rw.direct = true + n, readErr = dst.CopyOutFrom(ctx, rw) + putDentryReadWriter(rw) + if d.fs.opts.interop != InteropModeShared { + // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). + d.touchAtimeLocked(fd.vfsfd.Mount()) + } + } else { + rw := getDentryReadWriter(ctx, d, offset) + n, readErr = dst.CopyOutFrom(ctx, rw) + putDentryReadWriter(rw) + if d.fs.opts.interop != InteropModeShared { + // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). + d.touchAtime(fd.vfsfd.Mount()) + } } - n, err := dst.CopyOutFrom(ctx, rw) - putDentryReadWriter(rw) - if d.fs.opts.interop != InteropModeShared { - // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). - d.touchAtime(fd.vfsfd.Mount()) - } - return n, err + return n, readErr } // Read implements vfs.FileDescriptionImpl.Read. @@ -153,90 +153,134 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } // Check that flags are supported. // // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. if opts.Flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP + return 0, offset, syserror.EOPNOTSUPP } + d := fd.dentry() + // If the fd was opened with O_APPEND, make sure the file size is updated. + // There is a possible race here if size is modified externally after + // metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } + } + + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Set offset to file size if the fd was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) + } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) - d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() if d.fs.opts.interop != InteropModeShared { // Compare Linux's mm/filemap.c:__generic_file_write_iter() => // file_update_time(). This is d.touchCMtime(), but without locking // d.metadataMu (recursively). d.touchCMtimeLocked() } - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { - // Write dirty cached pages that will be touched by the write back to - // the remote file. - if err := d.writeback(ctx, offset, src.NumBytes()); err != nil { - return 0, err - } - // Remove touched pages from the cache. - pgstart := usermem.PageRoundDown(uint64(offset)) - pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes())) - if !ok { - return 0, syserror.EINVAL - } - mr := memmap.MappableRange{pgstart, pgend} - var freed []platform.FileRange - d.dataMu.Lock() - cseg := d.cache.LowerBoundSegment(mr.Start) - for cseg.Ok() && cseg.Start() < mr.End { - cseg = d.cache.Isolate(cseg, mr) - freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) - cseg = d.cache.Remove(cseg).NextSegment() - } - d.dataMu.Unlock() - // Invalidate mappings of removed pages. - d.mapsMu.Lock() - d.mappings.Invalidate(mr, memmap.InvalidateOpts{}) - d.mapsMu.Unlock() - // Finally free pages removed from the cache. - mf := d.fs.mfp.MemoryFile() - for _, freedFR := range freed { - mf.DecRef(freedFR) - } - } + rw := getDentryReadWriter(ctx, d, offset) + defer putDentryReadWriter(rw) + if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { + if err := fd.writeCache(ctx, d, offset, src); err != nil { + return 0, offset, err + } + // Require the write to go to the remote file. rw.direct = true } + n, err := src.CopyInTo(ctx, rw) - putDentryReadWriter(rw) - if n != 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 { + if err != nil { + return n, offset + n, err + } + if n > 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 { + // Note that if any of the following fail, then we can't guarantee that + // any data was actually written with the semantics of O_DSYNC or + // O_SYNC, so we return zero bytes written. Compare Linux's + // mm/filemap.c:generic_file_write_iter() => + // include/linux/fs.h:generic_write_sync(). + // // Write dirty cached pages touched by the write back to the remote // file. if err := d.writeback(ctx, offset, src.NumBytes()); err != nil { - return 0, err + return 0, offset, err } // Request the remote filesystem to sync the remote file. - if err := d.handle.file.fsync(ctx); err != nil { - return 0, err + if err := d.syncRemoteFile(ctx); err != nil { + return 0, offset, err } } - return n, err + return n, offset + n, nil +} + +func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64, src usermem.IOSequence) error { + // Write dirty cached pages that will be touched by the write back to + // the remote file. + if err := d.writeback(ctx, offset, src.NumBytes()); err != nil { + return err + } + + // Remove touched pages from the cache. + pgstart := usermem.PageRoundDown(uint64(offset)) + pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes())) + if !ok { + return syserror.EINVAL + } + mr := memmap.MappableRange{pgstart, pgend} + var freed []memmap.FileRange + + d.dataMu.Lock() + cseg := d.cache.LowerBoundSegment(mr.Start) + for cseg.Ok() && cseg.Start() < mr.End { + cseg = d.cache.Isolate(cseg, mr) + freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) + cseg = d.cache.Remove(cseg).NextSegment() + } + d.dataMu.Unlock() + + // Invalidate mappings of removed pages. + d.mapsMu.Lock() + d.mappings.Invalidate(mr, memmap.InvalidateOpts{}) + d.mapsMu.Unlock() + + // Finally free pages removed from the cache. + mf := d.fs.mfp.MemoryFile() + for _, freedFR := range freed { + mf.DecRef(freedFR) + } + return nil } // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -279,10 +323,11 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) // coherence with memory-mapped I/O), or if InteropModeShared is in effect // (which prevents us from caching file contents and makes dentry.size // unreliable), or if the file was opened O_DIRECT, read directly from - // dentry.handle without locking dentry.dataMu. + // dentry.readHandleLocked() without locking dentry.dataMu. rw.d.handleMu.RLock() - if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { - n, err := rw.d.handle.readToBlocksAt(rw.ctx, dsts, rw.off) + h := rw.d.readHandleLocked() + if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + n, err := h.readToBlocksAt(rw.ctx, dsts, rw.off) rw.d.handleMu.RUnlock() rw.off += n return n, err @@ -350,7 +395,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) End: gapEnd, } optMR := gap.Range() - err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt) + err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, h.readToBlocksAt) mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End}) seg, gap = rw.d.cache.Find(rw.off) if !seg.Ok() { @@ -365,7 +410,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) } else { // Read directly from the file. gapDsts := dsts.TakeFirst64(gapMR.Length()) - n, err := rw.d.handle.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start) + n, err := h.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start) done += n rw.off += n dsts = dsts.DropFirst64(n) @@ -397,11 +442,12 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro // If we have a mmappable host FD (which must be used here to ensure // coherence with memory-mapped I/O), or if InteropModeShared is in effect // (which prevents us from caching file contents), or if the file was - // opened with O_DIRECT, write directly to dentry.handle without locking - // dentry.dataMu. + // opened with O_DIRECT, write directly to dentry.writeHandleLocked() + // without locking dentry.dataMu. rw.d.handleMu.RLock() - if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { - n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off) + h := rw.d.writeHandleLocked() + if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + n, err := h.writeFromBlocksAt(rw.ctx, srcs, rw.off) rw.off += n rw.d.dataMu.Lock() if rw.off > rw.d.size { @@ -463,7 +509,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro // for detecting or avoiding this. gapMR := gap.Range().Intersect(mr) gapSrcs := srcs.TakeFirst64(gapMR.Length()) - n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start) + n, err := h.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start) done += n rw.off += n srcs = srcs.DropFirst64(n) @@ -489,7 +535,7 @@ exitLoop: if err := fsutil.SyncDirty(rw.ctx, memmap.MappableRange{ Start: start, End: rw.off, - }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, rw.d.handle.writeFromBlocksAt); err != nil { + }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, h.writeFromBlocksAt); err != nil { // We have no idea how many bytes were actually flushed. rw.off = start done = 0 @@ -507,6 +553,7 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error { } d.handleMu.RLock() defer d.handleMu.RUnlock() + h := d.writeHandleLocked() d.dataMu.Lock() defer d.dataMu.Unlock() // Compute the range of valid bytes (overflow-checked). @@ -520,7 +567,7 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error { return fsutil.SyncDirty(ctx, memmap.MappableRange{ Start: uint64(offset), End: uint64(end), - }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) } // Seek implements vfs.FileDescriptionImpl.Seek. @@ -577,24 +624,23 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncSharedHandle(ctx) + return fd.dentry().syncCachedFile(ctx) } -func (d *dentry) syncSharedHandle(ctx context.Context) error { +func (d *dentry) syncCachedFile(ctx context.Context) error { d.handleMu.RLock() defer d.handleMu.RUnlock() - if d.handleWritable { + if h := d.writeHandleLocked(); h.isOpen() { d.dataMu.Lock() // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) d.dataMu.Unlock() if err != nil { return err } } - // Sync the remote file. - return d.handle.sync(ctx) + return d.syncRemoteFileLocked(ctx) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -618,7 +664,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt return syserror.ENODEV } d.handleMu.RLock() - haveFD := d.handle.fd >= 0 + haveFD := d.hostFD >= 0 d.handleMu.RUnlock() if !haveFD { return syserror.ENODEV @@ -639,7 +685,7 @@ func (d *dentry) mayCachePages() bool { return true } d.handleMu.RLock() - haveFD := d.handle.fd >= 0 + haveFD := d.hostFD >= 0 d.handleMu.RUnlock() return haveFD } @@ -697,7 +743,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, // Translate implements memmap.Mappable.Translate. func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { d.handleMu.RLock() - if d.handle.fd >= 0 && !d.fs.opts.forcePageCache { + if d.hostFD >= 0 && !d.fs.opts.forcePageCache { d.handleMu.RUnlock() mr := optional if d.fs.opts.limitHostFDTranslation { @@ -733,7 +779,8 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab } mf := d.fs.mfp.MemoryFile() - cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, d.handle.readToBlocksAt) + h := d.readHandleLocked() + cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, h.readToBlocksAt) var ts []memmap.Translation var translatedEnd uint64 @@ -792,7 +839,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (d *dentry) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. d.mapsMu.Lock() @@ -802,9 +849,12 @@ func (d *dentry) InvalidateUnsavable(ctx context.Context) error { // Write the cache's contents back to the remote file so that if we have a // host fd after restore, the remote file's contents are coherent. mf := d.fs.mfp.MemoryFile() + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() d.dataMu.Lock() defer d.dataMu.Unlock() - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { return err } @@ -819,20 +869,23 @@ func (d *dentry) InvalidateUnsavable(ctx context.Context) error { // Evict implements pgalloc.EvictableMemoryUser.Evict. func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { + mr := memmap.MappableRange{er.Start, er.End} + mf := d.fs.mfp.MemoryFile() d.mapsMu.Lock() defer d.mapsMu.Unlock() + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() d.dataMu.Lock() defer d.dataMu.Unlock() - mr := memmap.MappableRange{er.Start, er.End} - mf := d.fs.mfp.MemoryFile() // Only allow pages that are no longer memory-mapped to be evicted. for mgap := d.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() { mgapMR := mgap.Range().Intersect(mr) if mgapMR.Length() == 0 { continue } - if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err) } d.cache.Drop(mgapMR, mf) @@ -840,53 +893,53 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { } } -// dentryPlatformFile implements platform.File. It exists solely because dentry -// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef. +// dentryPlatformFile implements memmap.File. It exists solely because dentry +// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file -// is available (i.e. dentry.handle.fd >= 0), and that FD is used for -// application memory mappings (i.e. !filesystem.opts.forcePageCache). +// is available (i.e. dentry.hostFD >= 0), and that FD is used for application +// memory mappings (i.e. !filesystem.opts.forcePageCache). +// +// +stateify savable type dentryPlatformFile struct { *dentry - // fdRefs counts references on platform.File offsets. fdRefs is protected + // fdRefs counts references on memmap.File offsets. fdRefs is protected // by dentry.dataMu. fdRefs fsutil.FrameRefSet - // If this dentry represents a regular file, and handle.fd >= 0, - // hostFileMapper caches mappings of handle.fd. + // If this dentry represents a regular file, and dentry.hostFD >= 0, + // hostFileMapper caches mappings of dentry.hostFD. hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once + hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. } -// IncRef implements platform.File.IncRef. -func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.IncRefAndAccount(fr) d.dataMu.Unlock() } -// DecRef implements platform.File.DecRef. -func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() - bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) - d.handleMu.RUnlock() - return bs, err + defer d.handleMu.RUnlock() + return d.hostFileMapper.MapInternal(fr, int(d.hostFD), at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() - fd := d.handle.fd - d.handleMu.RUnlock() - return int(fd) + defer d.handleMu.RUnlock() + return int(d.hostFD) } diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index d6dbe9092..326b940a7 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -36,12 +36,14 @@ func (d *dentry) isSocket() bool { // An endpoint's lifetime is the time between when filesystem.BoundEndpointAt() // is called and either BoundEndpoint.BidirectionalConnect or // BoundEndpoint.UnidirectionalConnect is called. +// +// +stateify savable type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry // file is the p9 file that contains a single unopened fid. - file p9.File + file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // path is the sentry path where this endpoint is bound. path string @@ -108,7 +110,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect // We don't need the receiver. c.CloseRecv() - c.Release() + c.Release(ctx) return c, nil } @@ -136,8 +138,8 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla } // Release implements transport.BoundEndpoint.Release. -func (e *endpoint) Release() { - e.dentry.DecRef() +func (e *endpoint) Release(ctx context.Context) { + e.dentry.DecRef(ctx) } // Passcred implements transport.BoundEndpoint.Passcred. diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index c1e6b13e5..71581736c 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -16,10 +16,13 @@ package gofer import ( "sync" + "sync/atomic" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -28,17 +31,25 @@ import ( ) // specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device -// special files, and (when filesystemOptions.specialRegularFiles is in effect) -// regular files. specialFileFD differs from regularFileFD by using per-FD -// handles instead of shared per-dentry handles, and never buffering I/O. +// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is +// in effect) regular files. specialFileFD differs from regularFileFD by using +// per-FD handles instead of shared per-dentry handles, and never buffering I/O. +// +// +stateify savable type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. - handle handle + handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + + // isRegularFile is true if this FD represents a regular file which is only + // possible when filesystemOptions.regularFilesUseSpecialFileFD is in + // effect. isRegularFile is immutable. + isRegularFile bool // seekable is true if this file description represents a file for which - // file offset is significant, i.e. a regular file. seekable is immutable. + // file offset is significant, i.e. a regular file, character device or + // block device. seekable is immutable. seekable bool // haveQueue is true if this file description represents a file for which @@ -47,18 +58,19 @@ type specialFileFD struct { queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 } func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() - seekable := ftype == linux.S_IFREG + seekable := ftype == linux.S_IFREG || ftype == linux.S_IFCHR || ftype == linux.S_IFBLK haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 fd := &specialFileFD{ - handle: h, - seekable: seekable, - haveQueue: haveQueue, + handle: h, + isRegularFile: ftype == linux.S_IFREG, + seekable: seekable, + haveQueue: haveQueue, } fd.LockFD.Init(locks) if haveQueue { @@ -79,11 +91,11 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *specialFileFD) Release() { +func (fd *specialFileFD) Release(ctx context.Context) { if fd.haveQueue { fdnotifier.RemoveFD(fd.handle.fd) } - fd.handle.close(context.Background()) + fd.handle.close(ctx) fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) fs.syncMu.Lock() delete(fs.specialFileFDs, fd) @@ -126,6 +138,16 @@ func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { fd.fileDescription.EventUnregister(e) } +func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + if fd.isRegularFile { + d := fd.dentry() + return d.doAllocate(ctx, offset, length, func() error { + return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) + } + return fd.FileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if fd.seekable && offset < 0 { @@ -144,7 +166,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't // hold here since specialFileFD doesn't client-cache data. Just buffer the // read instead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } buf := make([]byte, dst.NumBytes()) @@ -176,39 +198,82 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if fd.seekable && offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } // Check that flags are supported. // // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. if opts.Flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP + return 0, offset, syserror.EOPNOTSUPP + } + + d := fd.dentry() + // If the regular file fd was opened with O_APPEND, make sure the file size + // is updated. There is a possible race here if size is modified externally + // after metadata cache is updated. + if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } - if fd.seekable { + if fd.isRegularFile { + // We need to hold the metadataMu *while* writing to a regular file. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Set offset to file size if the regular file was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) + } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) } // Do a buffered write. See rationale in PRead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d.cachedMetadataAuthoritative() { d.touchCMtime() } buf := make([]byte, src.NumBytes()) - // Don't do partial writes if we get a partial read from src. - if _, err := src.CopyIn(ctx, buf); err != nil { - return 0, err + copied, copyErr := src.CopyIn(ctx, buf) + if copied == 0 && copyErr != nil { + // Only return the error if we didn't get any data. + return 0, offset, copyErr } - n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } - return int64(n), err + // Update offset if the offset is valid. + if offset >= 0 { + offset += int64(n) + } + // Update file size for regular files. + if fd.isRegularFile { + // d.metadataMu is already locked at this point. + if uint64(offset) > d.size { + d.dataMu.Lock() + defer d.dataMu.Unlock() + atomic.StoreUint64(&d.size, uint64(offset)) + } + } + if err != nil { + return int64(n), offset, err + } + return int64(n), offset, copyErr } // Write implements vfs.FileDescriptionImpl.Write. @@ -218,8 +283,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts } fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -241,5 +306,13 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncSharedHandle(ctx) + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 0eef4e16e..7e825caae 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -38,7 +38,7 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { - if mnt.Flags.NoATime { + if mnt.Flags.NoATime || mnt.ReadOnly() { return } if err := mnt.CheckBeginWrite(); err != nil { @@ -47,12 +47,28 @@ func (d *dentry) touchAtime(mnt *vfs.Mount) { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() atomic.StoreInt64(&d.atime, now) + atomic.StoreUint32(&d.atimeDirty, 1) d.metadataMu.Unlock() mnt.EndWrite() } -// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has -// successfully called vfs.Mount.CheckBeginWrite(). +// Preconditions: d.metadataMu is locked. d.cachedMetadataAuthoritative() == true. +func (d *dentry) touchAtimeLocked(mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + now := d.fs.clock.Now().Nanoseconds() + atomic.StoreInt64(&d.atime, now) + atomic.StoreUint32(&d.atimeDirty, 1) + mnt.EndWrite() +} + +// Preconditions: +// * d.cachedMetadataAuthoritative() == true. +// * The caller has successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -60,20 +76,24 @@ func (d *dentry) touchCtime() { d.metadataMu.Unlock() } -// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has -// successfully called vfs.Mount.CheckBeginWrite(). +// Preconditions: +// * d.cachedMetadataAuthoritative() == true. +// * The caller has successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCMtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() atomic.StoreInt64(&d.mtime, now) atomic.StoreInt64(&d.ctime, now) + atomic.StoreUint32(&d.mtimeDirty, 1) d.metadataMu.Unlock() } -// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has -// locked d.metadataMu. +// Preconditions: +// * d.cachedMetadataAuthoritative() == true. +// * The caller has locked d.metadataMu. func (d *dentry) touchCMtimeLocked() { now := d.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&d.mtime, now) atomic.StoreInt64(&d.ctime, now) + atomic.StoreUint32(&d.mtimeDirty, 1) } diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 44a09d87a..56bcf9bdb 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -1,12 +1,37 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "inode_refs", + out = "inode_refs.go", + package = "host", + prefix = "inode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "inode", + }, +) + +go_template_instance( + name = "connected_endpoint_refs", + out = "connected_endpoint_refs.go", + package = "host", + prefix = "ConnectedEndpoint", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "ConnectedEndpoint", + }, +) + go_library( name = "host", srcs = [ + "connected_endpoint_refs.go", "control.go", "host.go", + "inode_refs.go", "ioctl_unsafe.go", "mmap.go", "socket.go", @@ -22,7 +47,9 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/fspath", + "//pkg/iovec", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", @@ -33,7 +60,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index b9082a20f..0135e4428 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -58,7 +58,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (c *scmRights) Release() { +func (c *scmRights) Release(ctx context.Context) { for _, fd := range c.fds { syscall.Close(fd) } diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 1cd2982cb..ffe4ddb32 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" @@ -41,6 +40,44 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { + // Determine if hostFD is seekable. If not, this syscall will return ESPIPE + // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character + // devices. + _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) + seekable := err != syserror.ESPIPE + + i := &inode{ + hostFD: hostFD, + ino: fs.NextIno(), + isTTY: isTTY, + wouldBlock: wouldBlock(uint32(fileType)), + seekable: seekable, + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + canMap: fileType == linux.S_IFREG, + } + i.pf.inode = i + i.EnableLeakCheck() + + // Non-seekable files can't be memory mapped, assert this. + if !i.seekable && i.canMap { + panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") + } + + // If the hostFD would block, we must set it to non-blocking and handle + // blocking behavior in the sentry. + if i.wouldBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + return nil, err + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + return nil, err + } + } + return i, nil +} + // NewFDOptions contains options to NewFD. type NewFDOptions struct { // If IsTTY is true, the file descriptor is a TTY. @@ -76,53 +113,20 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) flags = uint32(flagsInt) } - fileMode := linux.FileMode(s.Mode) - fileType := fileMode.FileType() - - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. - _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) - seekable := err != syserror.ESPIPE - - i := &inode{ - hostFD: hostFD, - ino: fs.NextIno(), - isTTY: opts.IsTTY, - wouldBlock: wouldBlock(uint32(fileType)), - seekable: seekable, - // NOTE(b/38213152): Technically, some obscure char devices can be memory - // mapped, but we only allow regular files. - canMap: fileType == linux.S_IFREG, - } - i.pf.inode = i - - // Non-seekable files can't be memory mapped, assert this. - if !i.seekable && i.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - // If the hostFD would block, we must set it to non-blocking and handle - // blocking behavior in the sentry. - if i.wouldBlock { - if err := syscall.SetNonblock(i.hostFD, true); err != nil { - return nil, err - } - if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { - return nil, err - } - } - d := &kernfs.Dentry{} + i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY) + if err != nil { + return nil, err + } d.Init(i) // i.open will take a reference on d. - defer d.DecRef() + defer d.DecRef(ctx) // For simplicity, fileDescription.offset is set to 0. Technically, we // should only set to 0 on files that are not seekable (sockets, pipes, // etc.), and use the offset from the host fd otherwise when importing. - return i.open(ctx, d.VFSDentry(), mnt, flags) + return i.open(ctx, d, mnt, flags) } // ImportFD sets up and returns a vfs.FileDescription from a donated fd. @@ -133,14 +137,16 @@ func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs } // filesystemType implements vfs.FilesystemType. +// +// +stateify savable type filesystemType struct{} -// GetFilesystem implements FilesystemType.GetFilesystem. +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { panic("host.filesystemType.GetFilesystem should never be called") } -// Name implements FilesystemType.Name. +// Name implements vfs.FilesystemType.Name. func (filesystemType) Name() string { return "none" } @@ -162,15 +168,17 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { } // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { kernfs.Filesystem devMinor uint32 } -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { @@ -181,14 +189,17 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { + kernfs.InodeNoStatFS kernfs.InodeNotDirectory kernfs.InodeNotSymlink locks vfs.FileLocks // When the reference count reaches zero, the host fd is closed. - refs.AtomicRefCount + inodeRefs // hostFD contains the host fd that this file was originally created from, // which must be available at time of restore. @@ -228,7 +239,7 @@ type inode struct { canMap bool // mapsMu protects mappings. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` // If canMap is true, mappings tracks mappings of hostFD into // memmap.MappingSpaces. @@ -238,7 +249,7 @@ type inode struct { pf inodePlatformFile } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { @@ -247,7 +258,7 @@ func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, a return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid)) } -// Mode implements kernfs.Inode. +// Mode implements kernfs.Inode.Mode. func (i *inode) Mode() linux.FileMode { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { @@ -258,8 +269,8 @@ func (i *inode) Mode() linux.FileMode { return linux.FileMode(s.Mode) } -// Stat implements kernfs.Inode. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +// Stat implements kernfs.Inode.Stat. +func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL } @@ -371,9 +382,9 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { }, nil } -// SetStat implements kernfs.Inode. +// SetStat implements kernfs.Inode.SetStat. func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - s := opts.Stat + s := &opts.Stat m := s.Mask if m == 0 { @@ -386,7 +397,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if err := syscall.Fstat(i.hostFD, &hostStat); err != nil { return err } - if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { return err } @@ -396,6 +407,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } } if m&linux.STATX_SIZE != 0 { + if hostStat.Mode&linux.S_IFMT != linux.S_IFREG { + return syserror.EINVAL + } if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err } @@ -427,31 +441,28 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre return nil } -// DecRef implements kernfs.Inode. -func (i *inode) DecRef() { - i.AtomicRefCount.DecRefWithDestructor(i.Destroy) -} - -// Destroy implements kernfs.Inode. -func (i *inode) Destroy() { - if i.wouldBlock { - fdnotifier.RemoveFD(int32(i.hostFD)) - } - if err := unix.Close(i.hostFD); err != nil { - log.Warningf("failed to close host fd %d: %v", i.hostFD, err) - } +// DecRef implements kernfs.Inode.DecRef. +func (i *inode) DecRef(ctx context.Context) { + i.inodeRefs.DecRef(func() { + if i.wouldBlock { + fdnotifier.RemoveFD(int32(i.hostFD)) + } + if err := unix.Close(i.hostFD); err != nil { + log.Warningf("failed to close host fd %d: %v", i.hostFD, err) + } + }) } -// Open implements kernfs.Inode. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +// Open implements kernfs.Inode.Open. +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/. if i.Mode().FileType() == linux.S_IFSOCK { return nil, syserror.ENXIO } - return i.open(ctx, vfsd, rp.Mount(), opts.Flags) + return i.open(ctx, d, rp.Mount(), opts.Flags) } -func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) { +func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { return nil, err @@ -475,17 +486,17 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u return nil, err } // Currently, we only allow Unix sockets to be imported. - return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks) + return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d.VFSDentry(), &i.locks) case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR: if i.isTTY { fd := &TTYFileDescription{ fileDescription: fileDescription{inode: i}, - termios: linux.DefaultSlaveTermios, + termios: linux.DefaultReplicaTermios, } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd - if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil @@ -494,7 +505,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u fd := &fileDescription{inode: i} fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd - if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil @@ -506,6 +517,8 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u } // fileDescription is embedded by host fd implementations of FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -520,40 +533,35 @@ type fileDescription struct { inode *inode // offsetMu protects offset. - offsetMu sync.Mutex + offsetMu sync.Mutex `state:"nosave"` // offset specifies the current file offset. It is only meaningful when // inode.seekable is true. offset int64 } -// SetStat implements vfs.FileDescriptionImpl. +// SetStat implements vfs.FileDescriptionImpl.SetStat. func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts) } -// Stat implements vfs.FileDescriptionImpl. -func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { - return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts) +// Stat implements vfs.FileDescriptionImpl.Stat. +func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts) } -// Release implements vfs.FileDescriptionImpl. -func (f *fileDescription) Release() { +// Release implements vfs.FileDescriptionImpl.Release. +func (f *fileDescription) Release(context.Context) { // noop } -// Allocate implements vfs.FileDescriptionImpl. +// Allocate implements vfs.FileDescriptionImpl.Allocate. func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { - if !f.inode.seekable { - return syserror.ESPIPE - } - - // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds. - return syserror.EOPNOTSUPP + return unix.Fallocate(f.inode.hostFD, uint32(mode), int64(offset), int64(length)) } -// PRead implements FileDescriptionImpl. +// PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { i := f.inode if !i.seekable { @@ -563,7 +571,7 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags) } -// Read implements FileDescriptionImpl. +// Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { i := f.inode if !i.seekable { @@ -600,7 +608,7 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off return int64(n), err } -// PWrite implements FileDescriptionImpl. +// PWrite implements vfs.FileDescriptionImpl.PWrite. func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { if !f.inode.seekable { return 0, syserror.ESPIPE @@ -609,7 +617,7 @@ func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, of return f.writeToHostFD(ctx, src, offset, opts.Flags) } -// Write implements FileDescriptionImpl. +// Write implements vfs.FileDescriptionImpl.Write. func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { i := f.inode if !i.seekable { @@ -657,7 +665,7 @@ func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSeque return int64(n), err } -// Seek implements FileDescriptionImpl. +// Seek implements vfs.FileDescriptionImpl.Seek. // // Note that we do not support seeking on directories, since we do not even // allow directory fds to be imported at all. @@ -722,13 +730,13 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return f.offset, nil } -// Sync implements FileDescriptionImpl. +// Sync implements vfs.FileDescriptionImpl.Sync. func (f *fileDescription) Sync(context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } -// ConfigureMMap implements FileDescriptionImpl. +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { if !f.inode.canMap { return syserror.ENODEV diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index 8545a82f0..b51a17bed 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -19,22 +19,23 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// inodePlatformFile implements platform.File. It exists solely because inode -// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// inodePlatformFile implements memmap.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // // inodePlatformFile should only be used if inode.canMap is true. +// +// +stateify savable type inodePlatformFile struct { *inode // fdRefsMu protects fdRefs. - fdRefsMu sync.Mutex + fdRefsMu sync.Mutex `state:"nosave"` - // fdRefs counts references on platform.File offsets. It is used solely for + // fdRefs counts references on memmap.File offsets. It is used solely for // memory accounting. fdRefs fsutil.FrameRefSet @@ -42,35 +43,35 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once + fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. } -// IncRef implements platform.File.IncRef. +// IncRef implements memmap.File.IncRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) IncRef(fr platform.FileRange) { +func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) i.fdRefsMu.Unlock() } -// DecRef implements platform.File.DecRef. +// DecRef implements memmap.File.DecRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) DecRef(fr platform.FileRange) { +func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) i.fdRefsMu.Unlock() } -// MapInternal implements platform.File.MapInternal. +// MapInternal implements memmap.File.MapInternal. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (i *inodePlatformFile) FD() int { return i.hostFD } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index fd16bd92d..8a447e29f 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" @@ -59,8 +58,7 @@ func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transpor // // +stateify savable type ConnectedEndpoint struct { - // ref keeps track of references to a ConnectedEndpoint. - ref refs.AtomicRefCount + ConnectedEndpointRefs // mu protects fd below. mu sync.RWMutex `state:"nosave"` @@ -132,14 +130,14 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable return nil, err } - // AtomicRefCounters start off with a single reference. We need two. - e.ref.IncRef() - e.ref.EnableLeakCheck("host.ConnectedEndpoint") + // ConnectedEndpointRefs start off with a single reference. We need two. + e.IncRef() + e.EnableLeakCheck() return &e, nil } // Send implements transport.ConnectedEndpoint.Send. -func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -216,7 +214,7 @@ func (c *ConnectedEndpoint) EventUpdate() { } // Recv implements transport.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -317,8 +315,8 @@ func (c *ConnectedEndpoint) destroyLocked() { // Release implements transport.ConnectedEndpoint.Release and // transport.Receiver.Release. -func (c *ConnectedEndpoint) Release() { - c.ref.DecRefWithDestructor(func() { +func (c *ConnectedEndpoint) Release(ctx context.Context) { + c.DecRef(func() { c.mu.Lock() c.destroyLocked() c.mu.Unlock() @@ -347,13 +345,13 @@ func (e *SCMConnectedEndpoint) Init() error { // Release implements transport.ConnectedEndpoint.Release and // transport.Receiver.Release. -func (e *SCMConnectedEndpoint) Release() { - e.ref.DecRefWithDestructor(func() { +func (e *SCMConnectedEndpoint) Release(ctx context.Context) { + e.DecRef(func() { e.mu.Lock() + fdnotifier.RemoveFD(int32(e.fd)) if err := syscall.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) } - fdnotifier.RemoveFD(int32(e.fd)) e.destroyLocked() e.mu.Unlock() }) @@ -378,8 +376,8 @@ func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr s return nil, err } - // AtomicRefCounters start off with a single reference. We need two. - e.ref.IncRef() - e.ref.EnableLeakCheck("host.SCMConnectedEndpoint") + // ConnectedEndpointRefs start off with a single reference. We need two. + e.IncRef() + e.EnableLeakCheck() return &e, nil } diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go index 584c247d2..fc0d5fd38 100644 --- a/pkg/sentry/fsimpl/host/socket_iovec.go +++ b/pkg/sentry/fsimpl/host/socket_iovec.go @@ -17,13 +17,10 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -74,7 +71,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go index 35ded24bc..c0bf45f08 100644 --- a/pkg/sentry/fsimpl/host/socket_unsafe.go +++ b/pkg/sentry/fsimpl/host/socket_unsafe.go @@ -63,10 +63,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) ( controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC if n > length { - return length, n, msg.Controllen, controlTrunc, err + return length, n, msg.Controllen, controlTrunc, nil } - return n, n, msg.Controllen, controlTrunc, err + return n, n, msg.Controllen, controlTrunc, nil } // fdWriteVec sends from bufs to fd. diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 4ee9270cc..f5c596fec 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -29,6 +30,8 @@ import ( // TTYFileDescription implements vfs.FileDescriptionImpl for a host file // descriptor that wraps a TTY FD. +// +// +stateify savable type TTYFileDescription struct { fileDescription @@ -67,15 +70,15 @@ func (t *TTYFileDescription) ForegroundProcessGroup() *kernel.ProcessGroup { } // Release implements fs.FileOperations.Release. -func (t *TTYFileDescription) Release() { +func (t *TTYFileDescription) Release(ctx context.Context) { t.mu.Lock() t.fgProcessGroup = nil t.mu.Unlock() - t.fileDescription.Release() + t.fileDescription.Release(ctx) } -// PRead implements vfs.FileDescriptionImpl. +// PRead implements vfs.FileDescriptionImpl.PRead. // // Reading from a TTY is only allowed for foreground process groups. Background // process groups will either get EIO or a SIGTTIN. @@ -93,7 +96,7 @@ func (t *TTYFileDescription) PRead(ctx context.Context, dst usermem.IOSequence, return t.fileDescription.PRead(ctx, dst, offset, opts) } -// Read implements vfs.FileDescriptionImpl. +// Read implements vfs.FileDescriptionImpl.Read. // // Reading from a TTY is only allowed for foreground process groups. Background // process groups will either get EIO or a SIGTTIN. @@ -111,7 +114,7 @@ func (t *TTYFileDescription) Read(ctx context.Context, dst usermem.IOSequence, o return t.fileDescription.Read(ctx, dst, opts) } -// PWrite implements vfs.FileDescriptionImpl. +// PWrite implements vfs.FileDescriptionImpl.PWrite. func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { t.mu.Lock() defer t.mu.Unlock() @@ -126,7 +129,7 @@ func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, return t.fileDescription.PWrite(ctx, src, offset, opts) } -// Write implements vfs.FileDescriptionImpl. +// Write implements vfs.FileDescriptionImpl.Write. func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { t.mu.Lock() defer t.mu.Unlock() @@ -141,8 +144,13 @@ func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, return t.fileDescription.Write(ctx, src, opts) } -// Ioctl implements vfs.FileDescriptionImpl. +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + return 0, syserror.ENOTTY + } + // Ignore arg[0]. This is the real FD: fd := t.inode.hostFD ioctl := args[1].Uint64() @@ -152,9 +160,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = termios.CopyOut(task, args[2].Pointer()) return 0, err case linux.TCSETS, linux.TCSETSW, linux.TCSETSF: @@ -166,9 +172,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch } var termios linux.Termios - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := termios.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetTermios(fd, ioctl, &termios) @@ -192,10 +196,8 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch defer t.mu.Unlock() // Map the ProcessGroup into a ProcessGroupID in the task's PID namespace. - pgID := pidns.IDOfProcessGroup(t.fgProcessGroup) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }) + pgID := primitive.Int32(pidns.IDOfProcessGroup(t.fgProcessGroup)) + _, err := pgID.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSPGRP: @@ -203,11 +205,6 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch // Equivalent to tcsetpgrp(fd, *argp). // Set the foreground process group ID of this terminal. - task := kernel.TaskFromContext(ctx) - if task == nil { - return 0, syserror.ENOTTY - } - t.mu.Lock() defer t.mu.Unlock() @@ -226,12 +223,11 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch return 0, syserror.ENOTTY } - var pgID kernel.ProcessGroupID - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgIDP primitive.Int32 + if _, err := pgIDP.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } + pgID := kernel.ProcessGroupID(pgIDP) // pgID must be non-negative. if pgID < 0 { @@ -260,9 +256,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = winsize.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSWINSZ: @@ -273,9 +267,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch // set the winsize. var winsize linux.Winsize - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := winsize.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetWinsize(fd, &winsize) @@ -376,7 +368,7 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) // // Linux ignores the result of kill_pgrp(). _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) - return kernel.ERESTARTSYS + return syserror.ERESTARTSYS } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 179df6c1e..5e91e0536 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -26,9 +26,54 @@ go_template_instance( }, ) +go_template_instance( + name = "dentry_refs", + out = "dentry_refs.go", + package = "kernfs", + prefix = "Dentry", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Dentry", + }, +) + +go_template_instance( + name = "static_directory_refs", + out = "static_directory_refs.go", + package = "kernfs", + prefix = "StaticDirectory", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "StaticDirectory", + }, +) + +go_template_instance( + name = "dir_refs", + out = "dir_refs.go", + package = "kernfs_test", + prefix = "dir", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "dir", + }, +) + +go_template_instance( + name = "readonly_dir_refs", + out = "readonly_dir_refs.go", + package = "kernfs_test", + prefix = "readonlyDir", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "readonlyDir", + }, +) + go_library( name = "kernfs", srcs = [ + "dentry_refs.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", @@ -36,7 +81,9 @@ go_library( "inode_impl_util.go", "kernfs.go", "slot_list.go", + "static_directory_refs.go", "symlink.go", + "synthetic_directory.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -59,17 +106,23 @@ go_library( go_test( name = "kernfs_test", size = "small", - srcs = ["kernfs_test.go"], + srcs = [ + "dir_refs.go", + "kernfs_test.go", + "readonly_dir_refs.go", + ], deps = [ ":kernfs", "//pkg/abi/linux", "//pkg/context", + "//pkg/log", + "//pkg/refs", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 6886b0876..b929118b1 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -35,6 +35,7 @@ import ( // +stateify savable type DynamicBytesFile struct { InodeAttrs + InodeNoStatFS InodeNoopRefCount InodeNotDirectory InodeNotSymlink @@ -55,9 +56,9 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint } // Open implements Inode.Open. -func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &DynamicBytesFD{} - if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil { + if err := fd.Init(rp.Mount(), d, f.data, &f.locks, opts.Flags); err != nil { return nil, err } return &fd.vfsfd, nil @@ -86,12 +87,12 @@ type DynamicBytesFD struct { } // Init initializes a DynamicBytesFD. -func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error { +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error { fd.LockFD.Init(locks) - if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return err } - fd.inode = d.Impl().(*Dentry).inode + fd.inode = d.inode fd.SetDataSource(data) return nil } @@ -122,12 +123,12 @@ func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, of } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DynamicBytesFD) Release() {} +func (fd *DynamicBytesFD) Release(context.Context) {} // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(fs, opts) + return fd.inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index ca8b8c63b..0a4cd4057 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -15,7 +15,7 @@ package kernfs import ( - "math" + "fmt" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -28,9 +28,29 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// SeekEndConfig describes the SEEK_END behaviour for FDs. +// +// +stateify savable +type SeekEndConfig int + +// Constants related to SEEK_END behaviour for FDs. +const ( + // Consider the end of the file to be after the final static entry. This is + // the default option. + SeekEndStaticEntries = iota + // Consider the end of the file to be at offset 0. + SeekEndZero +) + +// GenericDirectoryFDOptions contains configuration for a GenericDirectoryFD. +// +// +stateify savable +type GenericDirectoryFDOptions struct { + SeekEnd SeekEndConfig +} + // GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory -// inode that uses OrderChildren to track child nodes. GenericDirectoryFD is not -// compatible with dynamic directories. +// inode that uses OrderChildren to track child nodes. // // Note that GenericDirectoryFD holds a lock over OrderedChildren while calling // IterDirents callback. The IterDirents callback therefore cannot hash or @@ -40,16 +60,21 @@ import ( // Must be initialize with Init before first use. // // Lock ordering: mu => children.mu. +// +// +stateify savable type GenericDirectoryFD struct { vfs.FileDescriptionDefaultImpl vfs.DirectoryFileDescriptionDefaultImpl vfs.LockFD + // Immutable. + seekEnd SeekEndConfig + vfsfd vfs.FileDescription children *OrderedChildren // mu protects the fields below. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // off is the current directory offset. Protected by "mu". off int64 @@ -57,12 +82,12 @@ type GenericDirectoryFD struct { // NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its // dentry. -func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { +func NewGenericDirectoryFD(m *vfs.Mount, d *Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) (*GenericDirectoryFD, error) { fd := &GenericDirectoryFD{} - if err := fd.Init(children, locks, opts); err != nil { + if err := fd.Init(children, locks, opts, fdOpts); err != nil { return nil, err } - if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(fd, opts.Flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return fd, nil @@ -71,12 +96,13 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre // Init initializes a GenericDirectoryFD. Use it when overriding // GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the // correct implementation. -func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error { +func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) error { if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { // Can't open directories for writing. return syserror.EISDIR } fd.LockFD.Init(locks) + fd.seekEnd = fdOpts.SeekEnd fd.children = children return nil } @@ -112,8 +138,8 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) } -// Release implements vfs.FileDecriptionImpl.Release. -func (fd *GenericDirectoryFD) Release() {} +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *GenericDirectoryFD) Release(context.Context) {} func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() @@ -123,7 +149,7 @@ func (fd *GenericDirectoryFD) inode() Inode { return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode } -// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds // o.mu when calling cb. func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { fd.mu.Lock() @@ -132,7 +158,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent opts := vfs.StatOptions{Mask: linux.STATX_INO} // Handle ".". if fd.off == 0 { - stat, err := fd.inode().Stat(fd.filesystem(), opts) + stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -152,7 +178,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent if fd.off == 1 { vfsd := fd.vfsfd.VirtualDentry().Dentry() parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode - stat, err := parentInode.Stat(fd.filesystem(), opts) + stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -175,8 +201,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // these. childIdx := fd.off - 2 for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { - inode := it.Dentry.Impl().(*Dentry).inode - stat, err := inode.Stat(fd.filesystem(), opts) + stat, err := it.Dentry.inode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -198,7 +223,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent return err } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() defer fd.mu.Unlock() @@ -209,9 +234,17 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int case linux.SEEK_CUR: offset += fd.off case linux.SEEK_END: - // TODO(gvisor.dev/issue/1193): This can prevent new files from showing up - // if they are added after SEEK_END. - offset = math.MaxInt64 + switch fd.seekEnd { + case SeekEndStaticEntries: + fd.children.mu.RLock() + offset += int64(len(fd.children.set)) + offset += 2 // '.' and '..' aren't tracked in children. + fd.children.mu.RUnlock() + case SeekEndZero: + // No-op: offset += 0. + default: + panic(fmt.Sprintf("Invalid GenericDirectoryFD.seekEnd = %v", fd.seekEnd)) + } default: return 0, syserror.EINVAL } @@ -226,7 +259,7 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.filesystem() inode := fd.inode() - return inode.Stat(fs, opts) + return inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 8939871c1..5cc1c4281 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -32,11 +32,12 @@ import ( // // stepExistingLocked is loosely analogous to fs/namei.c:walk_component(). // -// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * !rp.Done(). // // Postcondition: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) { - d := vfsd.Impl().(*Dentry) +func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, mayFollowSymlinks bool) (*Dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR } @@ -53,20 +54,20 @@ afterSymlink: // calls d_revalidate(), but walk_component() => handle_dots() does not. if name == "." { rp.Advance() - return vfsd, nil + return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, d.VFSDentry()); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() - return vfsd, nil + return d, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, d.parent.VFSDentry()); err != nil { return nil, err } rp.Advance() - return &d.parent.vfsd, nil + return d.parent, nil } if len(name) > linux.NAME_MAX { return nil, syserror.ENAMETOOLONG @@ -77,7 +78,7 @@ afterSymlink: if err != nil { return nil, err } - if err := rp.CheckMount(&next.vfsd); err != nil { + if err := rp.CheckMount(ctx, next.VFSDentry()); err != nil { return nil, err } // Resolve any symlink at current path component. @@ -88,7 +89,7 @@ afterSymlink: } if targetVD.Ok() { err := rp.HandleJump(targetVD) - targetVD.DecRef() + targetVD.DecRef(ctx) if err != nil { return nil, err } @@ -100,15 +101,18 @@ afterSymlink: goto afterSymlink } rp.Advance() - return &next.vfsd, nil + return next, nil } // revalidateChildLocked must be called after a call to parent.vfsd.Child(name) // or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be // nil) to verify that the returned child (or lack thereof) is correct. // -// Preconditions: Filesystem.mu must be locked for at least reading. -// parent.dirMu must be locked. parent.isDir(). name is not "." or "..". +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * parent.dirMu must be locked. +// * parent.isDir(). +// * name is not "." or "..". // // Postconditions: Caller must call fs.processDeferredDecRefs*. func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, child *Dentry) (*Dentry, error) { @@ -116,26 +120,22 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // Cached dentry exists, revalidate. if !child.inode.Valid(ctx) { delete(parent.children, name) - vfsObj.InvalidateDentry(&child.vfsd) - fs.deferDecRef(&child.vfsd) // Reference from Lookup. + vfsObj.InvalidateDentry(ctx, &child.vfsd) + fs.deferDecRef(child) // Reference from Lookup. child = nil } } if child == nil { - // Dentry isn't cached; it either doesn't exist or failed - // revalidation. Attempt to resolve it via Lookup. - // - // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return - // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes - // that all dentries in the filesystem are (kernfs.)Dentry and performs - // vfs.DentryImpl casts accordingly. - childVFSD, err := parent.inode.Lookup(ctx, name) + // Dentry isn't cached; it either doesn't exist or failed revalidation. + // Attempt to resolve it via Lookup. + c, err := parent.inode.Lookup(ctx, name) if err != nil { return nil, err } - // Reference on childVFSD dropped by a corresponding Valid. - child = childVFSD.Impl().(*Dentry) - parent.insertChildLocked(name, child) + // Reference on c (provided by Lookup) will be dropped when the dentry + // fails validation. + parent.InsertChildLocked(name, c) + child = c } return child, nil } @@ -148,20 +148,19 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // Preconditions: Filesystem.mu must be locked for at least reading. // // Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { - vfsd := rp.Start() +func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*Dentry, error) { + d := rp.Start().Impl().(*Dentry) for !rp.Done() { var err error - vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */) + d, err = fs.stepExistingLocked(ctx, rp, d, true /* mayFollowSymlinks */) if err != nil { - return nil, nil, err + return nil, err } } - d := vfsd.Impl().(*Dentry) if rp.MustBeDir() && !d.isDir() { - return nil, nil, syserror.ENOTDIR + return nil, syserror.ENOTDIR } - return vfsd, d.inode, nil + return d, nil } // walkParentDirLocked resolves all but the last path component of rp to an @@ -171,32 +170,34 @@ func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingP // walkParentDirLocked is loosely analogous to Linux's // fs/namei.c:path_parentat(). // -// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * !rp.Done(). // // Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { - vfsd := rp.Start() +func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*Dentry, error) { + d := rp.Start().Impl().(*Dentry) for !rp.Final() { var err error - vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */) + d, err = fs.stepExistingLocked(ctx, rp, d, true /* mayFollowSymlinks */) if err != nil { - return nil, nil, err + return nil, err } } - d := vfsd.Impl().(*Dentry) if !d.isDir() { - return nil, nil, syserror.ENOTDIR + return nil, syserror.ENOTDIR } - return vfsd, d.inode, nil + return d, nil } // checkCreateLocked checks that a file named rp.Component() may be created in // directory parentVFSD, then returns rp.Component(). // -// Preconditions: Filesystem.mu must be locked for at least reading. parentInode -// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * isDir(parentInode) == true. +func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) { + if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return "", err } pc := rp.Component() @@ -206,11 +207,10 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v if len(pc) > linux.NAME_MAX { return "", syserror.ENAMETOOLONG } - // FIXME(gvisor.dev/issue/1193): Data race due to not holding dirMu. - if _, ok := parentVFSD.Impl().(*Dentry).children[pc]; ok { + if _, ok := parent.children[pc]; ok { return "", syserror.EEXIST } - if parentVFSD.IsDead() { + if parent.VFSDentry().IsDead() { return "", syserror.ENOENT } return pc, nil @@ -219,8 +219,8 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v // checkDeleteLocked checks that the file represented by vfsd may be deleted. // // Preconditions: Filesystem.mu must be locked for at least reading. -func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error { - parent := vfsd.Impl().(*Dentry).parent +func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) error { + parent := d.parent if parent == nil { return syserror.EBUSY } @@ -234,7 +234,7 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Den } // Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release() { +func (fs *Filesystem) Release(context.Context) { } // Sync implements vfs.FilesystemImpl.Sync. @@ -246,35 +246,35 @@ func (fs *Filesystem) Sync(ctx context.Context) error { // AccessAt implements vfs.Filesystem.Impl.AccessAt. func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) if err != nil { return err } - return inode.CheckPermissions(ctx, creds, ats) + return d.inode.CheckPermissions(ctx, creds, ats) } // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) if err != nil { return nil, err } if opts.CheckSearchable { - d := vfsd.Impl().(*Dentry) if !d.isDir() { return nil, syserror.ENOTDIR } - if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { return nil, err } } + vfsd := d.VFSDentry() vfsd.IncRef() // Ownership transferred to caller. return vfsd, nil } @@ -282,14 +282,14 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op // GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() - vfsd, _, err := fs.walkParentDirLocked(ctx, rp) + d, err := fs.walkParentDirLocked(ctx, rp) if err != nil { return nil, err } - vfsd.IncRef() // Ownership transferred to caller. - return vfsd, nil + d.IncRef() // Ownership transferred to caller. + return d.VFSDentry(), nil } // LinkAt implements vfs.FilesystemImpl.LinkAt. @@ -299,12 +299,15 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + parent, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -321,11 +324,11 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EPERM } - childVFSD, err := parentInode.NewLink(ctx, pc, d.inode) + child, err := parent.inode.NewLink(ctx, pc, d.inode) if err != nil { return err } - parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, child) return nil } @@ -336,12 +339,15 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + parent, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -349,11 +355,14 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - childVFSD, err := parentInode.NewDir(ctx, pc, opts) + child, err := parent.inode.NewDir(ctx, pc, opts) if err != nil { - return err + if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + return err + } + child = newSyntheticDirectory(rp.Credentials(), opts.Mode) } - parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, child) return nil } @@ -364,12 +373,15 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + parent, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -377,11 +389,11 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - newVFSD, err := parentInode.NewNode(ctx, pc, opts) + newD, err := parent.inode.NewNode(ctx, pc, opts) if err != nil { return err } - parentVFSD.Impl().(*Dentry).InsertChild(pc, newVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, newD) return nil } @@ -397,24 +409,36 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // Do not create new file. if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() - defer fs.processDeferredDecRefs() - defer fs.mu.RUnlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() + fs.processDeferredDecRefs(ctx) return nil, err } - if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { + fs.mu.RUnlock() + fs.processDeferredDecRefs(ctx) return nil, err } - return inode.Open(ctx, rp, vfsd, opts) + d.inode.IncRef() + defer d.inode.DecRef(ctx) + fs.mu.RUnlock() + fs.processDeferredDecRefs(ctx) + return d.inode.Open(ctx, rp, d, opts) } // May create new file. mustCreate := opts.Flags&linux.O_EXCL != 0 - vfsd := rp.Start() - inode := vfsd.Impl().(*Dentry).inode + d := rp.Start().Impl().(*Dentry) fs.mu.Lock() - defer fs.mu.Unlock() + unlocked := false + unlock := func() { + if !unlocked { + fs.mu.Unlock() + unlocked = true + } + } + defer unlock() if rp.Done() { if rp.MustBeDir() { return nil, syserror.EISDIR @@ -422,19 +446,22 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { return nil, err } - return inode.Open(ctx, rp, vfsd, opts) + d.inode.IncRef() + defer d.inode.DecRef(ctx) + unlock() + return d.inode.Open(ctx, rp, d, opts) } afterTrailingSymlink: - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + parent, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return nil, err } // Check for search permission in the parent directory. - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { + if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { return nil, err } // Reject attempts to open directories with O_CREAT. @@ -449,10 +476,10 @@ afterTrailingSymlink: return nil, syserror.ENAMETOOLONG } // Determine whether or not we need to create a file. - childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD, false /* mayFollowSymlinks */) + child, err := fs.stepExistingLocked(ctx, rp, parent, false /* mayFollowSymlinks */) if err == syserror.ENOENT { // Already checked for searchability above; now check for writability. - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { + if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -460,13 +487,18 @@ afterTrailingSymlink: } defer rp.Mount().EndWrite() // Create and open the child. - childVFSD, err = parentInode.NewFile(ctx, pc, opts) + child, err := parent.inode.NewFile(ctx, pc, opts) if err != nil { return nil, err } - child := childVFSD.Impl().(*Dentry) - parentVFSD.Impl().(*Dentry).InsertChild(pc, child) - return child.inode.Open(ctx, rp, childVFSD, opts) + // FIXME(gvisor.dev/issue/1193): Race between checking existence with + // fs.stepExistingLocked and parent.InsertChild. If possible, we should hold + // dirMu from one to the other. + parent.InsertChild(pc, child) + child.inode.IncRef() + defer child.inode.DecRef(ctx) + unlock() + return child.inode.Open(ctx, rp, child, opts) } if err != nil { return nil, err @@ -475,7 +507,6 @@ afterTrailingSymlink: if mustCreate { return nil, syserror.EEXIST } - child := childVFSD.Impl().(*Dentry) if rp.ShouldFollowSymlink() && child.isSymlink() { targetVD, targetPathname, err := child.inode.Getlink(ctx, rp.Mount()) if err != nil { @@ -483,7 +514,7 @@ afterTrailingSymlink: } if targetVD.Ok() { err := rp.HandleJump(targetVD) - targetVD.DecRef() + targetVD.DecRef(ctx) if err != nil { return nil, err } @@ -499,22 +530,25 @@ afterTrailingSymlink: if err := child.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { return nil, err } - return child.inode.Open(ctx, rp, &child.vfsd, opts) + child.inode.IncRef() + defer child.inode.DecRef(ctx) + unlock() + return child.inode.Open(ctx, rp, child, opts) } // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { fs.mu.RLock() - d, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return "", err } - if !d.Impl().(*Dentry).isSymlink() { + if !d.isSymlink() { return "", syserror.EINVAL } - return inode.Readlink(ctx) + return d.inode.Readlink(ctx, rp.Mount()) } // RenameAt implements vfs.FilesystemImpl.RenameAt. @@ -526,16 +560,15 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 fs.mu.Lock() - defer fs.processDeferredDecRefsLocked() + defer fs.processDeferredDecRefsLocked(ctx) defer fs.mu.Unlock() // Resolve the destination directory first to verify that it's on this // Mount. - dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) + dstDir, err := fs.walkParentDirLocked(ctx, rp) if err != nil { return err } - dstDir := dstDirVFSD.Impl().(*Dentry) mnt := rp.Mount() if mnt != oldParentVD.Mount() { return syserror.EXDEV @@ -553,16 +586,15 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } - srcVFSD := &src.vfsd // Can we remove the src dentry? - if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil { + if err := checkDeleteLocked(ctx, rp, src); err != nil { return err } // Can we create the dst dentry? var dst *Dentry - pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode) + pc, err := checkCreateLocked(ctx, rp, dstDir) switch err { case nil: // Ok, continue with rename as replacement. @@ -573,18 +605,18 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } dst = dstDir.children[pc] if dst == nil { - panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD)) + panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDir)) } default: return err } var dstVFSD *vfs.Dentry if dst != nil { - dstVFSD = &dst.vfsd + dstVFSD = dst.VFSDentry() } mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) virtfs := rp.VirtualFilesystem() // We can't deadlock here due to lock ordering because we're protected from @@ -596,17 +628,18 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa defer dstDir.dirMu.Unlock() } + srcVFSD := src.VFSDentry() if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil { return err } - replaced, err := srcDir.inode.Rename(ctx, src.name, pc, srcVFSD, dstDirVFSD) + replaced, err := srcDir.inode.Rename(ctx, src.name, pc, src, dstDir) if err != nil { virtfs.AbortRenameDentry(srcVFSD, dstVFSD) return err } delete(srcDir.children, src.name) if srcDir != dstDir { - fs.deferDecRef(srcDirVFSD) + fs.deferDecRef(srcDir) dstDir.IncRef() } src.parent = dstDir @@ -615,7 +648,11 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa dstDir.children = make(map[string]*Dentry) } dstDir.children[pc] = src - virtfs.CommitRenameReplaceDentry(srcVFSD, replaced) + var replaceVFSD *vfs.Dentry + if replaced != nil { + replaceVFSD = replaced.VFSDentry() + } + virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD) return nil } @@ -623,8 +660,9 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + + d, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -632,14 +670,13 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } defer rp.Mount().EndWrite() - if err := checkDeleteLocked(ctx, rp, vfsd); err != nil { + if err := checkDeleteLocked(ctx, rp, d); err != nil { return err } - d := vfsd.Impl().(*Dentry) if !d.isDir() { return syserror.ENOTDIR } - if inode.HasChildren() { + if d.inode.HasChildren() { return syserror.ENOTEMPTY } virtfs := rp.VirtualFilesystem() @@ -648,56 +685,57 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error defer parentDentry.dirMu.Unlock() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) + vfsd := d.VFSDentry() if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { + + if err := parentDentry.inode.RmDir(ctx, d.name, d); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } - virtfs.CommitDeleteDentry(vfsd) + virtfs.CommitDeleteDentry(ctx, vfsd) return nil } // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return err } if opts.Stat.Mask == 0 { return nil } - return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts) + return d.inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts) } // StatAt implements vfs.FilesystemImpl.StatAt. func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statx{}, err } - return inode.Stat(fs.VFSFilesystem(), opts) + return d.inode.Stat(ctx, fs.VFSFilesystem(), opts) } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statfs{}, err } - // TODO(gvisor.dev/issue/1193): actually implement statfs. - return linux.Statfs{}, syserror.ENOSYS + return d.inode.StatFS(ctx, fs.VFSFilesystem()) } // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. @@ -707,12 +745,15 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + parent, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -720,11 +761,11 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ return err } defer rp.Mount().EndWrite() - childVFSD, err := parentInode.NewSymlink(ctx, pc, target) + child, err := parent.inode.NewSymlink(ctx, pc, target) if err != nil { return err } - parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, child) return nil } @@ -732,8 +773,9 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - vfsd, _, err := fs.walkExistingLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + + d, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -741,10 +783,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } defer rp.Mount().EndWrite() - if err := checkDeleteLocked(ctx, rp, vfsd); err != nil { + if err := checkDeleteLocked(ctx, rp, d); err != nil { return err } - d := vfsd.Impl().(*Dentry) if d.isDir() { return syserror.EISDIR } @@ -753,39 +794,40 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error parentDentry.dirMu.Lock() defer parentDentry.dirMu.Unlock() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) + vfsd := d.VFSDentry() if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { + if err := parentDentry.inode.Unlink(ctx, d.name, d); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } - virtfs.CommitDeleteDentry(vfsd) + virtfs.CommitDeleteDentry(ctx, vfsd) return nil } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return nil, err } - if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return nil, err } @@ -793,12 +835,12 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, syserror.ENOTSUP } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return "", err } @@ -806,12 +848,12 @@ func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", syserror.ENOTSUP } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return err } @@ -819,12 +861,12 @@ func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return syserror.ENOTSUP } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *Filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 4cb885d87..49210e748 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -32,6 +31,8 @@ import ( // count for inodes, performing no extra actions when references are obtained or // released. This is suitable for simple file inodes that don't reference any // resources. +// +// +stateify savable type InodeNoopRefCount struct { } @@ -40,7 +41,7 @@ func (InodeNoopRefCount) IncRef() { } // DecRef implements Inode.DecRef. -func (InodeNoopRefCount) DecRef() { +func (InodeNoopRefCount) DecRef(context.Context) { } // TryIncRef implements Inode.TryIncRef. @@ -48,37 +49,35 @@ func (InodeNoopRefCount) TryIncRef() bool { return true } -// Destroy implements Inode.Destroy. -func (InodeNoopRefCount) Destroy() { -} - // InodeDirectoryNoNewChildren partially implements the Inode interface. // InodeDirectoryNoNewChildren represents a directory inode which does not // support creation of new children. +// +// +stateify savable type InodeDirectoryNoNewChildren struct{} // NewFile implements Inode.NewFile. -func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) { return nil, syserror.EPERM } // NewDir implements Inode.NewDir. -func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) { return nil, syserror.EPERM } // NewLink implements Inode.NewLink. -func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*Dentry, error) { return nil, syserror.EPERM } // NewSymlink implements Inode.NewSymlink. -func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*Dentry, error) { return nil, syserror.EPERM } // NewNode implements Inode.NewNode. -func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) { return nil, syserror.EPERM } @@ -86,6 +85,8 @@ func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOpt // inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not // represent directories can embed this to provide no-op implementations for // directory-related functions. +// +// +stateify savable type InodeNotDirectory struct { } @@ -95,47 +96,47 @@ func (InodeNotDirectory) HasChildren() bool { } // NewFile implements Inode.NewFile. -func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) { panic("NewFile called on non-directory inode") } // NewDir implements Inode.NewDir. -func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) { panic("NewDir called on non-directory inode") } // NewLink implements Inode.NewLinkink. -func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*Dentry, error) { panic("NewLink called on non-directory inode") } // NewSymlink implements Inode.NewSymlink. -func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*Dentry, error) { panic("NewSymlink called on non-directory inode") } // NewNode implements Inode.NewNode. -func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) { panic("NewNode called on non-directory inode") } // Unlink implements Inode.Unlink. -func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { +func (InodeNotDirectory) Unlink(context.Context, string, *Dentry) error { panic("Unlink called on non-directory inode") } // RmDir implements Inode.RmDir. -func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { +func (InodeNotDirectory) RmDir(context.Context, string, *Dentry) error { panic("RmDir called on non-directory inode") } // Rename implements Inode.Rename. -func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { +func (InodeNotDirectory) Rename(context.Context, string, string, *Dentry, *Dentry) (*Dentry, error) { panic("Rename called on non-directory inode") } // Lookup implements Inode.Lookup. -func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*Dentry, error) { panic("Lookup called on non-directory inode") } @@ -154,10 +155,12 @@ func (InodeNotDirectory) Valid(context.Context) bool { // dymanic entries (i.e. entries that are not "hashed" into the // vfs.Dentry.children) can embed this to provide no-op implementations for // functions related to dynamic entries. +// +// +stateify savable type InodeNoDynamicLookup struct{} // Lookup implements Inode.Lookup. -func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*Dentry, error) { return nil, syserror.ENOENT } @@ -174,10 +177,12 @@ func (InodeNoDynamicLookup) Valid(ctx context.Context) bool { // InodeNotSymlink partially implements the Inode interface, specifically the // inodeSymlink sub interface. All inodes that are not symlinks may embed this // to return the appropriate errors from symlink-related functions. +// +// +stateify savable type InodeNotSymlink struct{} // Readlink implements Inode.Readlink. -func (InodeNotSymlink) Readlink(context.Context) (string, error) { +func (InodeNotSymlink) Readlink(context.Context, *vfs.Mount) (string, error) { return "", syserror.EINVAL } @@ -191,6 +196,8 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // inode attributes. // // Must be initialized by Init prior to first use. +// +// +stateify savable type InodeAttrs struct { devMajor uint32 devMinor uint32 @@ -243,7 +250,7 @@ func (a *InodeAttrs) Mode() linux.FileMode { // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. -func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { +func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK stat.DevMajor = a.devMajor @@ -261,13 +268,30 @@ func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return a.SetInodeStat(ctx, fs, creds, opts) +} + +// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. +// This function can be used by other kernfs-based filesystem implementation to +// sets the unexported attributes into kernfs.InodeAttrs. +func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { + + // Note that not all fields are modifiable. For example, the file type and + // inode numbers are immutable after node creation. Setting the size is often + // allowed by kernfs files but does not do anything. If some other behavior is + // needed, the embedder should consider extending SetStat. + // + // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { + if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { + return syserror.EISDIR + } + if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } @@ -289,13 +313,6 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut atomic.StoreUint32(&a.gid, stat.GID) } - // Note that not all fields are modifiable. For example, the file type and - // inode numbers are immutable after node creation. - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - // Also, STATX_SIZE will need some special handling, because read-only static - // files should return EIO for truncate operations. - return nil } @@ -325,13 +342,16 @@ func (a *InodeAttrs) DecLinks() { } } +// +stateify savable type slot struct { Name string - Dentry *vfs.Dentry + Dentry *Dentry slotEntry } // OrderedChildrenOptions contains initialization options for OrderedChildren. +// +// +stateify savable type OrderedChildrenOptions struct { // Writable indicates whether vfs.FilesystemImpl methods implemented by // OrderedChildren may modify the tracked children. This applies to @@ -347,14 +367,14 @@ type OrderedChildrenOptions struct { // directories. // // Must be initialize with Init before first use. +// +// +stateify savable type OrderedChildren struct { - refs.AtomicRefCount - // Can children be modified by user syscalls? It set to false, interface // methods that would modify the children return EPERM. Immutable. writable bool - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` order slotList set map[string]*slot } @@ -365,12 +385,9 @@ func (o *OrderedChildren) Init(opts OrderedChildrenOptions) { o.set = make(map[string]*slot) } -// DecRef implements Inode.DecRef. -func (o *OrderedChildren) DecRef() { - o.AtomicRefCount.DecRefWithDestructor(o.Destroy) -} - -// Destroy cleans up resources referenced by this OrderedChildren. +// Destroy clears the children stored in o. It should be called by structs +// embedding OrderedChildren upon destruction, i.e. when their reference count +// reaches zero. func (o *OrderedChildren) Destroy() { o.mu.Lock() defer o.mu.Unlock() @@ -390,7 +407,7 @@ func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint3 if child.isDir() { links++ } - if err := o.Insert(name, child.VFSDentry()); err != nil { + if err := o.Insert(name, child); err != nil { panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d)) } d.InsertChild(name, child) @@ -407,7 +424,7 @@ func (o *OrderedChildren) HasChildren() bool { // Insert inserts child into o. This ignores the writability of o, as this is // not part of the vfs.FilesystemImpl interface, and is a lower-level operation. -func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error { +func (o *OrderedChildren) Insert(name string, child *Dentry) error { o.mu.Lock() defer o.mu.Unlock() if _, ok := o.set[name]; ok { @@ -431,10 +448,10 @@ func (o *OrderedChildren) removeLocked(name string) { } // Precondition: caller must hold o.mu for writing. -func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry { +func (o *OrderedChildren) replaceChildLocked(name string, new *Dentry) *Dentry { if s, ok := o.set[name]; ok { // Existing slot with given name, simply replace the dentry. - var old *vfs.Dentry + var old *Dentry old, s.Dentry = s.Dentry, new return old } @@ -450,7 +467,7 @@ func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs. } // Precondition: caller must hold o.mu for reading or writing. -func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error { +func (o *OrderedChildren) checkExistingLocked(name string, child *Dentry) error { s, ok := o.set[name] if !ok { return syserror.ENOENT @@ -462,7 +479,7 @@ func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) er } // Unlink implements Inode.Unlink. -func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { +func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry) error { if !o.writable { return syserror.EPERM } @@ -478,12 +495,13 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.De } // Rmdir implements Inode.Rmdir. -func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { +func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *Dentry) error { // We're not responsible for checking that child is a directory, that it's // empty, or updating any link counts; so this is the same as unlink. return o.Unlink(ctx, name, child) } +// +stateify savable type renameAcrossDifferentImplementationsError struct{} func (renameAcrossDifferentImplementationsError) Error() string { @@ -499,8 +517,8 @@ func (renameAcrossDifferentImplementationsError) Error() string { // that will support Rename. // // Postcondition: reference on any replaced dentry transferred to caller. -func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) { - dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren) +func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (*Dentry, error) { + dst, ok := dstDir.inode.(interface{}).(*OrderedChildren) if !ok { return nil, renameAcrossDifferentImplementationsError{} } @@ -542,12 +560,14 @@ func (o *OrderedChildren) nthLocked(i int64) *slot { } // InodeSymlink partially implements Inode interface for symlinks. +// +// +stateify savable type InodeSymlink struct { InodeNotDirectory } // Open implements Inode.Open. -func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return nil, syserror.ELOOP } @@ -556,21 +576,25 @@ func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D // // +stateify savable type StaticDirectory struct { - InodeNotSymlink - InodeDirectoryNoNewChildren InodeAttrs + InodeDirectoryNoNewChildren InodeNoDynamicLookup + InodeNoStatFS + InodeNotSymlink OrderedChildren + StaticDirectoryRefs - locks vfs.FileLocks + locks vfs.FileLocks + fdOpts GenericDirectoryFDOptions } var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry { +func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry, fdOpts GenericDirectoryFDOptions) *Dentry { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm) + inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) + inode.EnableLeakCheck() dentry := &Dentry{} dentry.Init(inode) @@ -583,31 +607,50 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } + s.fdOpts = fdOpts s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } -// Open implements kernfs.Inode. -func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts) +// Open implements kernfs.Inode.Open. +func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := NewGenericDirectoryFD(rp.Mount(), d, &s.OrderedChildren, &s.locks, &opts, s.fdOpts) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } +// DecRef implements kernfs.Inode.DecRef. +func (s *StaticDirectory) DecRef(context.Context) { + s.StaticDirectoryRefs.DecRef(s.Destroy) +} + // AlwaysValid partially implements kernfs.inodeDynamicLookup. +// +// +stateify savable type AlwaysValid struct{} -// Valid implements kernfs.inodeDynamicLookup. +// Valid implements kernfs.inodeDynamicLookup.Valid. func (*AlwaysValid) Valid(context.Context) bool { return true } + +// InodeNoStatFS partially implements the Inode interface, where the client +// filesystem doesn't support statfs(2). +// +// +stateify savable +type InodeNoStatFS struct{} + +// StatFS implements Inode.StatFS. +func (*InodeNoStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return linux.Statfs{}, syserror.ENOSYS +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 596de1edf..6d3d79333 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -29,7 +29,7 @@ // // Reference Model: // -// Kernfs dentries represents named pointers to inodes. Dentries and inode have +// Kernfs dentries represents named pointers to inodes. Dentries and inodes have // independent lifetimes and reference counts. A child dentry unconditionally // holds a reference on its parent directory's dentry. A dentry also holds a // reference on the inode it points to. Multiple dentries can point to the same @@ -57,24 +57,26 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" ) // Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory // filesystem. Concrete implementations are expected to embed this in their own // Filesystem type. +// +// +stateify savable type Filesystem struct { vfsfs vfs.Filesystem - droppedDentriesMu sync.Mutex + droppedDentriesMu sync.Mutex `state:"nosave"` // droppedDentries is a list of dentries waiting to be DecRef()ed. This is // used to defer dentry destruction until mu can be acquired for // writing. Protected by droppedDentriesMu. - droppedDentries []*vfs.Dentry + droppedDentries []*Dentry // mu synchronizes the lifetime of Dentries on this filesystem. Holding it // for reading guarantees continued existence of any resolved dentries, but @@ -97,7 +99,7 @@ type Filesystem struct { // defer fs.mu.RUnlock() // ... // fs.deferDecRef(dentry) - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. @@ -108,7 +110,7 @@ type Filesystem struct { // processDeferredDecRefs{,Locked}. See comment on Filesystem.mu. // // Precondition: d must not already be pending destruction. -func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { +func (fs *Filesystem) deferDecRef(d *Dentry) { fs.droppedDentriesMu.Lock() fs.droppedDentries = append(fs.droppedDentries, d) fs.droppedDentriesMu.Unlock() @@ -116,17 +118,17 @@ func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { // processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the // droppedDentries list. See comment on Filesystem.mu. -func (fs *Filesystem) processDeferredDecRefs() { +func (fs *Filesystem) processDeferredDecRefs(ctx context.Context) { fs.mu.Lock() - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) fs.mu.Unlock() } // Precondition: fs.mu must be held for writing. -func (fs *Filesystem) processDeferredDecRefsLocked() { +func (fs *Filesystem) processDeferredDecRefsLocked(ctx context.Context) { fs.droppedDentriesMu.Lock() for _, d := range fs.droppedDentries { - d.DecRef() + d.DecRef(ctx) } fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse. fs.droppedDentriesMu.Unlock() @@ -160,10 +162,12 @@ const ( // to, and child dentries hold a reference on their parent. // // Must be initialized by Init prior to first use. +// +// +stateify savable type Dentry struct { - vfsd vfs.Dentry + DentryRefs - refs.AtomicRefCount + vfsd vfs.Dentry // flags caches useful information about the dentry from the inode. See the // dflags* consts above. Must be accessed by atomic ops. @@ -173,7 +177,11 @@ type Dentry struct { name string // dirMu protects children and the names of child Dentries. - dirMu sync.Mutex + // + // Note that holding fs.mu for writing is not sufficient; + // revalidateChildLocked(), which is a very hot path, may modify children with + // fs.mu acquired for reading only. + dirMu sync.Mutex `state:"nosave"` children map[string]*Dentry inode Inode @@ -194,6 +202,7 @@ func (d *Dentry) Init(inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } + d.EnableLeakCheck() } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -212,17 +221,15 @@ func (d *Dentry) isSymlink() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef() { - d.AtomicRefCount.DecRefWithDestructor(d.destroy) -} - -// Precondition: Dentry must be removed from VFS' dentry cache. -func (d *Dentry) destroy() { - d.inode.DecRef() // IncRef from Init. - d.inode = nil - if d.parent != nil { - d.parent.DecRef() // IncRef from Dentry.InsertChild. - } +func (d *Dentry) DecRef(ctx context.Context) { + // Before the destructor is called, Dentry must be removed from VFS' dentry cache. + d.DentryRefs.DecRef(func() { + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + if d.parent != nil { + d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild. + } + }) } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -230,7 +237,7 @@ func (d *Dentry) destroy() { // Although Linux technically supports inotify on pseudo filesystems (inotify // is implemented at the vfs layer), it is not particularly useful. It is left // unimplemented until someone actually needs it. -func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} +func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. func (d *Dentry) Watches() *vfs.Watches { @@ -238,27 +245,28 @@ func (d *Dentry) Watches() *vfs.Watches { } // OnZeroWatches implements vfs.Dentry.OnZeroWatches. -func (d *Dentry) OnZeroWatches() {} +func (d *Dentry) OnZeroWatches(context.Context) {} // InsertChild inserts child into the vfs dentry cache with the given name under -// this dentry. This does not update the directory inode, so calling this on -// its own isn't sufficient to insert a child into a directory. InsertChild -// updates the link count on d if required. +// this dentry. This does not update the directory inode, so calling this on its +// own isn't sufficient to insert a child into a directory. // // Precondition: d must represent a directory inode. func (d *Dentry) InsertChild(name string, child *Dentry) { d.dirMu.Lock() - d.insertChildLocked(name, child) + d.InsertChildLocked(name, child) d.dirMu.Unlock() } -// insertChildLocked is equivalent to InsertChild, with additional +// InsertChildLocked is equivalent to InsertChild, with additional // preconditions. // -// Precondition: d.dirMu must be locked. -func (d *Dentry) insertChildLocked(name string, child *Dentry) { +// Preconditions: +// * d must represent a directory inode. +// * d.dirMu must be locked. +func (d *Dentry) InsertChildLocked(name string, child *Dentry) { if !d.isDir() { - panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) + panic(fmt.Sprintf("InsertChildLocked called on non-directory Dentry: %+v.", d)) } d.IncRef() // DecRef in child's Dentry.destroy. child.parent = d @@ -269,6 +277,36 @@ func (d *Dentry) insertChildLocked(name string, child *Dentry) { d.children[name] = child } +// RemoveChild removes child from the vfs dentry cache. This does not update the +// directory inode or modify the inode to be unlinked. So calling this on its own +// isn't sufficient to remove a child from a directory. +// +// Precondition: d must represent a directory inode. +func (d *Dentry) RemoveChild(name string, child *Dentry) error { + d.dirMu.Lock() + defer d.dirMu.Unlock() + return d.RemoveChildLocked(name, child) +} + +// RemoveChildLocked is equivalent to RemoveChild, with additional +// preconditions. +// +// Precondition: d.dirMu must be locked. +func (d *Dentry) RemoveChildLocked(name string, child *Dentry) error { + if !d.isDir() { + panic(fmt.Sprintf("RemoveChild called on non-directory Dentry: %+v.", d)) + } + c, ok := d.children[name] + if !ok { + return syserror.ENOENT + } + if c != child { + panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! Child: %+v, vfs: %+v", c, child)) + } + delete(d.children, name) + return nil +} + // Inode returns the dentry's inode. func (d *Dentry) Inode() Inode { return d.inode @@ -289,7 +327,6 @@ func (d *Dentry) Inode() Inode { // // - Checking that dentries passed to methods are of the appropriate file type. // - Checking permissions. -// - Updating link and reference counts. // // Specific responsibilities of implementations are documented below. type Inode interface { @@ -299,7 +336,8 @@ type Inode interface { inodeRefs // Methods related to node metadata. A generic implementation is provided by - // InodeAttrs. + // InodeAttrs. Note that a concrete filesystem using kernfs is responsible for + // managing link counts. inodeMetadata // Method for inodes that represent symlink. InodeNotSymlink provides a @@ -317,21 +355,22 @@ type Inode interface { // Open creates a file description for the filesystem object represented by // this inode. The returned file description should hold a reference on the - // inode for its lifetime. + // dentry for its lifetime. // // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing // the inode on which Open() is being called. - Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) + Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) + + // StatFS returns filesystem statistics for the client filesystem. This + // corresponds to vfs.FilesystemImpl.StatFSAt. If the client filesystem + // doesn't support statfs(2), this should return ENOSYS. + StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) } type inodeRefs interface { IncRef() - DecRef() + DecRef(ctx context.Context) TryIncRef() bool - // Destroy is called when the inode reaches zero references. Destroy release - // all resources (references) on objects referenced by the inode, including - // any child dentries. - Destroy() } type inodeMetadata interface { @@ -346,7 +385,7 @@ type inodeMetadata interface { // Stat returns the metadata for this inode. This corresponds to // vfs.FilesystemImpl.StatAt. - Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) + Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) // SetStat updates the metadata for this inode. This corresponds to // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking @@ -370,30 +409,30 @@ type inodeDirectory interface { HasChildren() bool // NewFile creates a new regular file inode. - NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) + NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) // NewDir creates a new directory inode. - NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) + NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) // NewLink creates a new hardlink to a specified inode in this // directory. Implementations should create a new kernfs Dentry pointing to // target, and update target's link count. - NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) + NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) // NewSymlink creates a new symbolic link inode. - NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) + NewSymlink(ctx context.Context, name, target string) (*Dentry, error) // NewNode creates a new filesystem node for a mknod syscall. - NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) + NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) // Unlink removes a child dentry from this directory inode. - Unlink(ctx context.Context, name string, child *vfs.Dentry) error + Unlink(ctx context.Context, name string, child *Dentry) error // RmDir removes an empty child directory from this directory // inode. Implementations must update the parent directory's link count, // if required. Implementations are not responsible for checking that child // is a directory, checking for an empty directory. - RmDir(ctx context.Context, name string, child *vfs.Dentry) error + RmDir(ctx context.Context, name string, child *Dentry) error // Rename is called on the source directory containing an inode being // renamed. child should point to the resolved child in the source @@ -401,7 +440,7 @@ type inodeDirectory interface { // should return the replaced dentry or nil otherwise. // // Precondition: Caller must serialize concurrent calls to Rename. - Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error) + Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (replaced *Dentry, err error) } type inodeDynamicLookup interface { @@ -419,14 +458,14 @@ type inodeDynamicLookup interface { // // Lookup returns the child with an extra reference and the caller owns this // reference. - Lookup(ctx context.Context, name string) (*vfs.Dentry, error) + Lookup(ctx context.Context, name string) (*Dentry, error) // Valid should return true if this inode is still valid, or needs to // be resolved again by a call to Lookup. Valid(ctx context.Context) bool // IterDirents is used to iterate over dynamically created entries. It invokes - // cb on each entry in the directory represented by the FileDescription. + // cb on each entry in the directory represented by the Inode. // 'offset' is the offset for the entire IterDirents call, which may include // results from the caller (e.g. "." and ".."). 'relOffset' is the offset // inside the entries returned by this IterDirents invocation. In other words, @@ -438,7 +477,7 @@ type inodeDynamicLookup interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. - Readlink(ctx context.Context) (string, error) + Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path // resolution: diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index dc407eb1d..e413242dc 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -46,13 +46,13 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System { ctx := contexttest.Context(t) creds := auth.CredentialsFromContext(ctx) v := &vfs.VirtualFilesystem{} - if err := v.Init(); err != nil { + if err := v.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) + mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.MountOptions{}) if err != nil { t.Fatalf("Failed to create testfs root mount: %v", err) } @@ -96,10 +96,12 @@ func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.S } type readonlyDir struct { + readonlyDirRefs attrs - kernfs.InodeNotSymlink - kernfs.InodeNoDynamicLookup kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNoDynamicLookup + kernfs.InodeNoStatFS + kernfs.InodeNotSymlink kernfs.OrderedChildren locks vfs.FileLocks @@ -111,6 +113,7 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod dir := &readonlyDir{} dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + dir.EnableLeakCheck() dir.dentry.Init(dir) dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) @@ -118,19 +121,27 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod return &dir.dentry } -func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) +func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +func (d *readonlyDir) DecRef(context.Context) { + d.readonlyDirRefs.DecRef(d.Destroy) +} + type dir struct { + dirRefs attrs - kernfs.InodeNotSymlink kernfs.InodeNoDynamicLookup + kernfs.InodeNotSymlink kernfs.OrderedChildren + kernfs.InodeNoStatFS locks vfs.FileLocks @@ -143,6 +154,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte dir.fs = fs dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) + dir.EnableLeakCheck() dir.dentry.Init(dir) dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) @@ -150,46 +162,50 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte return &dir.dentry } -func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) +func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { +func (d *dir) DecRef(context.Context) { + d.dirRefs.DecRef(d.Destroy) +} + +func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) { creds := auth.CredentialsFromContext(ctx) dir := d.fs.newDir(creds, opts.Mode, nil) - dirVFSD := dir.VFSDentry() - if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { - dir.DecRef() + if err := d.OrderedChildren.Insert(name, dir); err != nil { + dir.DecRef(ctx) return nil, err } d.IncLinks(1) - return dirVFSD, nil + return dir, nil } -func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { +func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) { creds := auth.CredentialsFromContext(ctx) f := d.fs.newFile(creds, "") - fVFSD := f.VFSDentry() - if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { - f.DecRef() + if err := d.OrderedChildren.Insert(name, f); err != nil { + f.DecRef(ctx) return nil, err } - return fVFSD, nil + return f, nil } -func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) { +func (*dir) NewLink(context.Context, string, kernfs.Inode) (*kernfs.Dentry, error) { return nil, syserror.EPERM } -func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (*dir) NewSymlink(context.Context, string, string) (*kernfs.Dentry, error) { return nil, syserror.EPERM } -func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*kernfs.Dentry, error) { return nil, syserror.EPERM } @@ -213,7 +229,7 @@ func TestBasic(t *testing.T) { }) }) defer sys.Destroy() - sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef() + sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef(sys.Ctx) } func TestMkdirGetDentry(t *testing.T) { @@ -228,7 +244,7 @@ func TestMkdirGetDentry(t *testing.T) { if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) } - sys.GetDentryOrDie(pop).DecRef() + sys.GetDentryOrDie(pop).DecRef(sys.Ctx) } func TestReadStaticFile(t *testing.T) { @@ -246,7 +262,7 @@ func TestReadStaticFile(t *testing.T) { if err != nil { t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) } - defer fd.DecRef() + defer fd.DecRef(sys.Ctx) content, err := sys.ReadToEnd(fd) if err != nil { @@ -273,7 +289,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } // Close the file. The file should persist. - fd.DecRef() + fd.DecRef(sys.Ctx) fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ Flags: linux.O_RDONLY, @@ -281,7 +297,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) { if err != nil { t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) } - fd.DecRef() + fd.DecRef(sys.Ctx) } func TestDirFDReadWrite(t *testing.T) { @@ -297,7 +313,7 @@ func TestDirFDReadWrite(t *testing.T) { if err != nil { t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) } - defer fd.DecRef() + defer fd.DecRef(sys.Ctx) // Read/Write should fail for directory FDs. if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 2ab3f53fd..58a93eaac 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -24,10 +24,13 @@ import ( // StaticSymlink provides an Inode implementation for symlinks that point to // a immutable target. +// +// +stateify savable type StaticSymlink struct { InodeAttrs InodeNoopRefCount InodeSymlink + InodeNoStatFS target string } @@ -50,8 +53,8 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } -// Readlink implements Inode. -func (s *StaticSymlink) Readlink(_ context.Context) (string, error) { +// Readlink implements Inode.Readlink. +func (s *StaticSymlink) Readlink(_ context.Context, _ *vfs.Mount) (string, error) { return s.target, nil } diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go new file mode 100644 index 000000000..ea7f073eb --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -0,0 +1,102 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// syntheticDirectory implements kernfs.Inode for a directory created by +// MkdirAt(ForSyntheticMountpoint=true). +// +// +stateify savable +type syntheticDirectory struct { + InodeAttrs + InodeNoStatFS + InodeNoopRefCount + InodeNoDynamicLookup + InodeNotSymlink + OrderedChildren + + locks vfs.FileLocks +} + +var _ Inode = (*syntheticDirectory)(nil) + +func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) *Dentry { + inode := &syntheticDirectory{} + inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + d := &Dentry{} + d.Init(inode) + return d +} + +func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) + } + dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.OrderedChildren.Init(OrderedChildrenOptions{ + Writable: true, + }) +} + +// Open implements Inode.Open. +func (dir *syntheticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := NewGenericDirectoryFD(rp.Mount(), d, &dir.OrderedChildren, &dir.locks, &opts, GenericDirectoryFDOptions{}) + if err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// NewFile implements Inode.NewFile. +func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) { + return nil, syserror.EPERM +} + +// NewDir implements Inode.NewDir. +func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) { + if !opts.ForSyntheticMountpoint { + return nil, syserror.EPERM + } + subdird := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + if err := dir.OrderedChildren.Insert(name, subdird); err != nil { + subdird.DecRef(ctx) + return nil, err + } + return subdird, nil +} + +// NewLink implements Inode.NewLink. +func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) { + return nil, syserror.EPERM +} + +// NewSymlink implements Inode.NewSymlink. +func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (*Dentry, error) { + return nil, syserror.EPERM +} + +// NewNode implements Inode.NewNode. +func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) { + return nil, syserror.EPERM +} diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 8f8dcfafe..73b126669 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -22,6 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -40,6 +42,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return nil } + // Attach our credentials to the context, as some VFS operations use + // credentials from context rather an take an explicit creds parameter. + ctx = auth.ContextWithCredentials(ctx, d.fs.creds) + ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT switch ftype { case linux.S_IFREG, linux.S_IFDIR, linux.S_IFLNK, linux.S_IFBLK, linux.S_IFCHR: @@ -76,6 +82,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Start: d.parent.upperVD, Path: fspath.Parse(d.name), } + // Used during copy-up of memory-mapped regular files. + var mmapOpts *memmap.MMapOpts cleanupUndoCopyUp := func() { var err error if ftype == linux.S_IFDIR { @@ -84,7 +92,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop) } if err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)) + } + if d.upperVD.Ok() { + d.upperVD.DecRef(ctx) + d.upperVD = vfs.VirtualDentry{} } } switch ftype { @@ -98,7 +110,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { if err != nil { return err } - defer oldFD.DecRef() + defer oldFD.DecRef(ctx) newFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &newpop, &vfs.OpenOptions{ Flags: linux.O_WRONLY | linux.O_CREAT | linux.O_EXCL, Mode: linux.FileMode(d.mode &^ linux.S_IFMT), @@ -106,7 +118,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { if err != nil { return err } - defer newFD.DecRef() + defer newFD.DecRef(ctx) bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size for { readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{}) @@ -127,6 +139,25 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { break } } + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if d.wrappedMappable != nil { + // We may have memory mappings of the file on the lower layer. + // Switch to mapping the file on the upper layer instead. + mmapOpts = &memmap.MMapOpts{ + Perms: usermem.ReadWrite, + MaxPerms: usermem.ReadWrite, + } + if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil { + cleanupUndoCopyUp() + return err + } + if mmapOpts.MappingIdentity != nil { + mmapOpts.MappingIdentity.DecRef(ctx) + } + // Don't actually switch Mappables until the end of copy-up; see + // below for why. + } if err := newFD.SetStat(ctx, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_UID | linux.STATX_GID, @@ -229,7 +260,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { panic(fmt.Sprintf("unexpected file type %o", ftype)) } - // TODO(gvisor.dev/issue/1199): copy up xattrs + if err := d.copyXattrsLocked(ctx); err != nil { + cleanupUndoCopyUp() + return err + } // Update the dentry's device and inode numbers (except for directories, // for which these remain overlay-assigned). @@ -241,14 +275,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Mask: linux.STATX_INO, }) if err != nil { - d.upperVD.DecRef() - d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return err } if upperStat.Mask&linux.STATX_INO == 0 { - d.upperVD.DecRef() - d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return syserror.EREMOTE } @@ -257,6 +287,135 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { atomic.StoreUint64(&d.ino, upperStat.Ino) } + if mmapOpts != nil && mmapOpts.Mappable != nil { + // Note that if mmapOpts != nil, then d.mapsMu is locked for writing + // (from the S_IFREG path above). + + // Propagate mappings of d to the new Mappable. Remember which mappings + // we added so we can remove them on failure. + upperMappable := mmapOpts.Mappable + allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange) + for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + added := make(memmap.MappingsOfRange) + for m := range seg.Value() { + if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil { + for m := range added { + upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) + } + for mr, mappings := range allAdded { + for m := range mappings { + upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable) + } + } + return err + } + added[m] = struct{}{} + } + allAdded[seg.Range()] = added + } + + // Switch to the new Mappable. We do this at the end of copy-up + // because: + // + // - We need to switch Mappables (by changing d.wrappedMappable) before + // invalidating Translations from the old Mappable (to pick up + // Translations from the new one). + // + // - We need to lock d.dataMu while changing d.wrappedMappable, but + // must invalidate Translations with d.dataMu unlocked (due to lock + // ordering). + // + // - Consequently, once we unlock d.dataMu, other threads may + // immediately observe the new (copied-up) Mappable, which we want to + // delay until copy-up is guaranteed to succeed. + d.dataMu.Lock() + lowerMappable := d.wrappedMappable + d.wrappedMappable = upperMappable + d.dataMu.Unlock() + d.lowerMappings.InvalidateAll(memmap.InvalidateOpts{}) + + // Remove mappings from the old Mappable. + for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + for m := range seg.Value() { + lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) + } + } + d.lowerMappings.RemoveAll() + } + atomic.StoreUint32(&d.copiedUp, 1) return nil } + +// copyXattrsLocked copies a subset of lower's extended attributes to upper. +// Attributes that configure an overlay in the lower are not copied up. +// +// Preconditions: d.copyMu must be locked for writing. +func (d *dentry) copyXattrsLocked(ctx context.Context) error { + vfsObj := d.fs.vfsfs.VirtualFilesystem() + lowerPop := &vfs.PathOperation{Root: d.lowerVDs[0], Start: d.lowerVDs[0]} + upperPop := &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD} + + lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0) + if err != nil { + if err == syserror.EOPNOTSUPP { + // There are no guarantees as to the contents of lowerXattrs. + return nil + } + ctx.Infof("failed to copy up xattrs because ListXattrAt failed: %v", err) + return err + } + + for _, name := range lowerXattrs { + // Do not copy up overlay attributes. + if isOverlayXattr(name) { + continue + } + + value, err := vfsObj.GetXattrAt(ctx, d.fs.creds, lowerPop, &vfs.GetXattrOptions{Name: name, Size: 0}) + if err != nil { + ctx.Infof("failed to copy up xattrs because GetXattrAt failed: %v", err) + return err + } + + if err := vfsObj.SetXattrAt(ctx, d.fs.creds, upperPop, &vfs.SetXattrOptions{Name: name, Value: value}); err != nil { + ctx.Infof("failed to copy up xattrs because SetXattrAt failed: %v", err) + return err + } + } + return nil +} + +// copyUpDescendantsLocked ensures that all descendants of d are copied up. +// +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) error { + dirents, err := d.getDirentsLocked(ctx) + if err != nil { + return err + } + for _, dirent := range dirents { + if dirent.Name == "." || dirent.Name == ".." { + continue + } + child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + if err != nil { + return err + } + if err := child.copyUpLocked(ctx); err != nil { + return err + } + if child.isDir() { + child.dirMu.Lock() + err := child.copyUpDescendantsLocked(ctx, ds) + child.dirMu.Unlock() + if err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go index f5c2462a5..df4492346 100644 --- a/pkg/sentry/fsimpl/overlay/directory.go +++ b/pkg/sentry/fsimpl/overlay/directory.go @@ -29,7 +29,9 @@ func (d *dentry) isDir() bool { return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR } -// Preconditions: d.dirMu must be locked. d.isDir(). +// Preconditions: +// * d.dirMu must be locked. +// * d.isDir(). func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string]bool, error) { vfsObj := d.fs.vfsfs.VirtualFilesystem() var readdirErr error @@ -46,12 +48,12 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string readdirErr = err return false } - defer layerFD.DecRef() + defer layerFD.DecRef(ctx) // Reuse slice allocated for maybeWhiteouts from a previous layer to // reduce allocations. maybeWhiteouts = maybeWhiteouts[:0] - if err := layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { + err = layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { if dirent.Name == "." || dirent.Name == ".." { return nil } @@ -68,7 +70,8 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string } // Non-whiteout file in the directory prevents rmdir. return syserror.ENOTEMPTY - })); err != nil { + })) + if err != nil { readdirErr = err return false } @@ -97,26 +100,29 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string return whiteouts, readdirErr } +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 dirents []vfs.Dirent } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(ctx context.Context) { } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + d := fd.dentry() + defer d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent) + fd.mu.Lock() defer fd.mu.Unlock() - d := fd.dentry() if fd.dirents == nil { ds, err := d.getDirents(ctx) if err != nil { @@ -140,7 +146,14 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { defer d.fs.renameMu.RUnlock() d.dirMu.Lock() defer d.dirMu.Unlock() + return d.getDirentsLocked(ctx) +} +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +func (d *dentry) getDirentsLocked(ctx context.Context) ([]vfs.Dirent, error) { if d.dirents != nil { return d.dirents, nil } @@ -177,12 +190,12 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { readdirErr = err return false } - defer layerFD.DecRef() + defer layerFD.DecRef(ctx) // Reuse slice allocated for maybeWhiteouts from a previous layer to // reduce allocations. maybeWhiteouts = maybeWhiteouts[:0] - if err := layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { + err = layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { if dirent.Name == "." || dirent.Name == ".." { return nil } @@ -201,7 +214,8 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { dirent.NextOff = int64(len(dirents) + 1) dirents = append(dirents, dirent) return nil - })); err != nil { + })) + if err != nil { readdirErr = err return false } @@ -282,6 +296,6 @@ func (fd *directoryFD) Sync(ctx context.Context) error { return err } err = upperFD.Sync(ctx) - upperFD.DecRef() + upperFD.DecRef(ctx) return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index ff82e1f20..bd11372d5 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -15,6 +15,8 @@ package overlay import ( + "fmt" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,10 +29,15 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// _OVL_XATTR_PREFIX is an extended attribute key prefix to identify overlayfs +// attributes. +// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_PREFIX +const _OVL_XATTR_PREFIX = linux.XATTR_TRUSTED_PREFIX + "overlay." + // _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for // opaque directories. // Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE -const _OVL_XATTR_OPAQUE = "trusted.overlay.opaque" +const _OVL_XATTR_OPAQUE = _OVL_XATTR_PREFIX + "opaque" func isWhiteout(stat *linux.Statx) bool { return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0 @@ -77,7 +84,7 @@ func putDentrySlice(ds *[]*dentry) { // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { fs.renameMu.RUnlock() if *ds == nil { return @@ -85,20 +92,20 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) { if len(**ds) != 0 { fs.renameMu.Lock() for _, d := range **ds { - d.checkDropLocked() + d.checkDropLocked(ctx) } fs.renameMu.Unlock() } putDentrySlice(*ds) } -func (fs *filesystem) renameMuUnlockAndCheckDrop(ds **[]*dentry) { +func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { if *ds == nil { fs.renameMu.Unlock() return } for _, d := range **ds { - d.checkDropLocked() + d.checkDropLocked(ctx) } fs.renameMu.Unlock() putDentrySlice(*ds) @@ -110,8 +117,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ds **[]*dentry) { // Dentries which may have a reference count of zero, and which therefore // should be dropped once traversal is complete, are appended to ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR @@ -126,13 +135,13 @@ afterSymlink: return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() return d, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } rp.Advance() @@ -142,7 +151,7 @@ afterSymlink: if err != nil { return nil, err } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { @@ -159,7 +168,9 @@ afterSymlink: return child, nil } -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { return child, nil @@ -177,7 +188,9 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s return child, nil } -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { childPath := fspath.Parse(name) child := fs.newDentry() @@ -199,6 +212,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str lookupErr = err return false } + defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) if !existsOnAnyLayer { @@ -237,6 +251,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } // Update child to include this layer. + childVD.IncRef() if isUpper { child.upperVD = childVD child.copiedUp = 1 @@ -261,10 +276,10 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str // Directories are merged with directories from lower layers if they // are not explicitly opaque. - opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{ + opaqueVal, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{ Root: childVD, Start: childVD, - }, &vfs.GetxattrOptions{ + }, &vfs.GetXattrOptions{ Name: _OVL_XATTR_OPAQUE, Size: 1, }) @@ -272,11 +287,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str }) if lookupErr != nil { - child.destroyLocked() + child.destroyLocked(ctx) return nil, lookupErr } if !existsOnAnyLayer { - child.destroyLocked() + child.destroyLocked(ctx) return nil, syserror.ENOENT } @@ -300,7 +315,9 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str // lookupLayerLocked is similar to lookupLocked, but only returns information // about the file rather than a dentry. // -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, name string) (lookupLayer, error) { childPath := fspath.Parse(name) lookupLayer := lookupLayerNone @@ -385,7 +402,9 @@ func (ll lookupLayer) existsInOverlay() bool { // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -425,12 +444,13 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // doCreateAt checks that creating a file at rp is permitted, then invokes // create to do so. // -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). +// Preconditions: +// * !rp.Done(). +// * For the final path component in rp, !rp.ShouldFollowSymlink(). func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -479,7 +499,13 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err := create(parent, name, childLayer == lookupLayerUpperWhiteout); err != nil { return err } + parent.dirents = nil + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */) return nil } @@ -493,7 +519,7 @@ func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFil func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) { if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)) } } @@ -501,7 +527,7 @@ func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.V func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -513,7 +539,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -532,7 +558,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -553,7 +579,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -605,12 +631,13 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop) } return err } + old.watches.Notify(ctx, "", linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent, false /* unlinked */) return nil }) } @@ -644,7 +671,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }, }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -654,12 +681,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // There may be directories on lower layers (previously hidden by // the whiteout) that the new directory should not be merged with. // Mark it opaque to prevent merging. - if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{ + if err := vfsObj.SetXattrAt(ctx, fs.creds, &pop, &vfs.SetXattrOptions{ Name: _OVL_XATTR_OPAQUE, Value: "y", }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)) } else { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -703,7 +730,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -717,17 +744,36 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { mayCreate := opts.Flags&linux.O_CREAT != 0 mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL) + mayWrite := vfs.AccessTypesForOpenFlags(&opts).MayWrite() var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + unlocked := false + unlock := func() { + if !unlocked { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + unlocked = true + } + } + defer unlock() start := rp.Start().Impl().(*dentry) if rp.Done() { + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } if mustCreate { return nil, syserror.EEXIST } - return start.openLocked(ctx, rp, &opts) + if mayWrite { + if err := start.copyUpLocked(ctx); err != nil { + return nil, err + } + } + start.IncRef() + defer start.DecRef(ctx) + unlock() + return start.openCopiedUp(ctx, rp, &opts) } afterTrailingSymlink: @@ -739,6 +785,10 @@ afterTrailingSymlink: if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } + // Reject attempts to open directories with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } // Determine whether or not we need to create a file. parent.dirMu.Lock() child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) @@ -747,12 +797,11 @@ afterTrailingSymlink: parent.dirMu.Unlock() return fd, err } + parent.dirMu.Unlock() if err != nil { - parent.dirMu.Unlock() return nil, err } // Open existing child or follow symlink. - parent.dirMu.Unlock() if mustCreate { return nil, syserror.EEXIST } @@ -767,20 +816,27 @@ afterTrailingSymlink: start = parent goto afterTrailingSymlink } - return child.openLocked(ctx, rp, &opts) + if rp.MustBeDir() && !child.isDir() { + return nil, syserror.ENOTDIR + } + if mayWrite { + if err := child.copyUpLocked(ctx); err != nil { + return nil, err + } + } + child.IncRef() + defer child.DecRef(ctx) + unlock() + return child.openCopiedUp(ctx, rp, &opts) } -// Preconditions: fs.renameMu must be locked. -func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +// Preconditions: If vfs.AccessTypesForOpenFlags(opts).MayWrite(), then d has +// been copied up. +func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if err := d.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } - if ats.MayWrite() { - if err := d.copyUpLocked(ctx); err != nil { - return nil, err - } - } mnt := rp.Mount() // Directory FDs open FDs from each layer when directory entries are read, @@ -792,7 +848,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, syserror.EISDIR } // Can't open directories writably. - if ats&vfs.MayWrite != 0 { + if ats.MayWrite() { return nil, syserror.EISDIR } if opts.Flags&linux.O_DIRECT != 0 { @@ -825,14 +881,15 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf fd.LockFD.Init(&d.locks) layerFDOpts := layerFD.Options() if err := fd.vfsfd.Init(fd, layerFlags, mnt, &d.vfsd, &layerFDOpts); err != nil { - layerFD.DecRef() + layerFD.DecRef(ctx) return nil, err } return &fd.vfsfd, nil } -// Preconditions: parent.dirMu must be locked. parent does not already contain -// a child named rp.Component(). +// Preconditions: +// * parent.dirMu must be locked. +// * parent does not already contain a child named rp.Component(). func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { creds := rp.Credentials() if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil { @@ -893,7 +950,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -904,7 +961,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving child, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -920,11 +977,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving fd.LockFD.Init(&child.locks) upperFDOpts := upperFD.Options() if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil { - upperFD.DecRef() + upperFD.DecRef(ctx) // Don't bother with cleanup; the file was created successfully, we // just can't open it anymore for some reason. return nil, err } + parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */) return &fd.vfsfd, nil } @@ -932,7 +990,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -952,7 +1010,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa var ds *[]*dentry fs.renameMu.Lock() - defer fs.renameMuUnlockAndCheckDrop(&ds) + defer fs.renameMuUnlockAndCheckDrop(ctx, &ds) newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds) if err != nil { return err @@ -970,16 +1028,231 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } defer mnt.EndWrite() - // FIXME(gvisor.dev/issue/1199): Actually implement rename. - _ = newParent - return syserror.EXDEV + oldParent := oldParentVD.Dentry().Impl().(*dentry) + creds := rp.Credentials() + if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a + // directory, we need to check for write permission on it. + oldParent.dirMu.Lock() + defer oldParent.dirMu.Unlock() + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + if err != nil { + return err + } + if err := vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&oldParent.mode)), auth.KUID(atomic.LoadUint32(&renamed.uid))); err != nil { + return err + } + if renamed.isDir() { + if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { + return syserror.EINVAL + } + if oldParent != newParent { + if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + } + } else { + if opts.MustBeDir || rp.MustBeDir() { + return syserror.ENOTDIR + } + } + + if oldParent != newParent { + if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + newParent.dirMu.Lock() + defer newParent.dirMu.Unlock() + } + if newParent.vfsd.IsDead() { + return syserror.ENOENT + } + replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) + if err != nil { + return err + } + var ( + replaced *dentry + replacedVFSD *vfs.Dentry + whiteouts map[string]bool + ) + if replacedLayer.existsInOverlay() { + replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil { + return err + } + replacedVFSD = &replaced.vfsd + if replaced.isDir() { + if !renamed.isDir() { + return syserror.EISDIR + } + if genericIsAncestorDentry(replaced, renamed) { + return syserror.ENOTEMPTY + } + replaced.dirMu.Lock() + defer replaced.dirMu.Unlock() + whiteouts, err = replaced.collectWhiteoutsForRmdirLocked(ctx) + if err != nil { + return err + } + } else { + if rp.MustBeDir() || renamed.isDir() { + return syserror.ENOTDIR + } + } + } + + if oldParent == newParent && oldName == newName { + return nil + } + + // renamed and oldParent need to be copied-up before they're renamed on the + // upper layer. + if err := renamed.copyUpLocked(ctx); err != nil { + return err + } + // If renamed is a directory, all of its descendants need to be copied-up + // before they're renamed on the upper layer. + if renamed.isDir() { + if err := renamed.copyUpDescendantsLocked(ctx, &ds); err != nil { + return err + } + } + // newParent must be copied-up before it can contain renamed on the upper + // layer. + if err := newParent.copyUpLocked(ctx); err != nil { + return err + } + // If replaced exists, it doesn't need to be copied-up, but we do need to + // serialize with copy-up. Holding renameMu for writing should be + // sufficient, but out of an abundance of caution... + if replaced != nil { + replaced.copyMu.RLock() + defer replaced.copyMu.RUnlock() + } + + vfsObj := rp.VirtualFilesystem() + mntns := vfs.MountNamespaceFromContext(ctx) + defer mntns.DecRef(ctx) + if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil { + return err + } + + newpop := vfs.PathOperation{ + Root: newParent.upperVD, + Start: newParent.upperVD, + Path: fspath.Parse(newName), + } + + needRecreateWhiteouts := false + cleanupRecreateWhiteouts := func() { + if !needRecreateWhiteouts { + return + } + for whiteoutName, whiteoutUpper := range whiteouts { + if !whiteoutUpper { + continue + } + if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{ + Root: replaced.upperVD, + Start: replaced.upperVD, + Path: fspath.Parse(whiteoutName), + }); err != nil && err != syserror.EEXIST { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err)) + } + } + } + if renamed.isDir() { + if replacedLayer == lookupLayerUpper { + // Remove whiteouts from the directory being replaced. + needRecreateWhiteouts = true + for whiteoutName, whiteoutUpper := range whiteouts { + if !whiteoutUpper { + continue + } + if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{ + Root: replaced.upperVD, + Start: replaced.upperVD, + Path: fspath.Parse(whiteoutName), + }); err != nil { + cleanupRecreateWhiteouts() + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + } + } else if replacedLayer == lookupLayerUpperWhiteout { + // We need to explicitly remove the whiteout since otherwise rename + // on the upper layer will fail with ENOTDIR. + if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil { + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + } + } + + // Essentially no gVisor filesystem supports RENAME_WHITEOUT, so just do a + // regular rename and create the whiteout at the origin manually. Unlike + // RENAME_WHITEOUT, this isn't atomic with respect to other users of the + // upper filesystem, but this is already the case for virtually all other + // overlay filesystem operations too. + oldpop := vfs.PathOperation{ + Root: oldParent.upperVD, + Start: oldParent.upperVD, + Path: fspath.Parse(oldName), + } + if err := vfsObj.RenameAt(ctx, creds, &oldpop, &newpop, &opts); err != nil { + cleanupRecreateWhiteouts() + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + + // Below this point, the renamed dentry is now at newpop, and anything we + // replaced is gone forever. Commit the rename, update the overlay + // filesystem tree, and abandon attempts to recover from errors. + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) + delete(oldParent.children, oldName) + if replaced != nil { + ds = appendDentry(ds, replaced) + } + if oldParent != newParent { + newParent.dirents = nil + // This can't drop the last reference on oldParent because one is held + // by oldParentVD, so lock recursion is impossible. + oldParent.DecRef(ctx) + ds = appendDentry(ds, oldParent) + newParent.IncRef() + renamed.parent = newParent + } + renamed.name = newName + if newParent.children == nil { + newParent.children = make(map[string]*dentry) + } + newParent.children[newName] = renamed + oldParent.dirents = nil + + if err := fs.createWhiteout(ctx, vfsObj, &oldpop); err != nil { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout at origin after RenameAt: %v", err)) + } + if renamed.isDir() { + if err := vfsObj.SetXattrAt(ctx, fs.creds, &newpop, &vfs.SetXattrOptions{ + Name: _OVL_XATTR_OPAQUE, + Value: "y", + }); err != nil { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to make renamed directory opaque: %v", err)) + } + } + + vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir()) + return nil } // RmdirAt implements vfs.FilesystemImpl.RmdirAt. func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -1001,7 +1274,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -1051,7 +1324,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error Start: child.upperVD, Path: fspath.Parse(whiteoutName), }); err != nil && err != syserror.EEXIST { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)) } } } @@ -1081,15 +1354,14 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Don't attempt to recover from this: the original directory is // already gone, so any dentries representing it are invalid, and // creating a new directory won't undo that. - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err) - vfsObj.AbortDeleteDentry(&child.vfsd) - return err + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)) } - vfsObj.CommitDeleteDentry(&child.vfsd) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) delete(parent.children, name) ds = appendDentry(ds, child) parent.dirents = nil + parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0 /* cookie */, vfs.InodeEvent, true /* unlinked */) return nil } @@ -1097,14 +1369,27 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + return err + } + err = d.setStatLocked(ctx, rp, opts) + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + if err != nil { return err } + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ctx, ev, 0 /* cookie */, vfs.InodeEvent) + } + return nil +} + +// Precondition: d.fs.renameMu must be held for reading. +func (d *dentry) setStatLocked(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := rp.Mount() @@ -1132,7 +1417,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statx{}, err @@ -1160,7 +1445,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) _, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statfs{}, err @@ -1197,7 +1482,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -1211,7 +1496,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -1233,7 +1518,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -1290,70 +1575,175 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } } if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err) - if child != nil { - vfsObj.AbortDeleteDentry(&child.vfsd) - } - return err + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)) } + var cw *vfs.Watches if child != nil { - vfsObj.CommitDeleteDentry(&child.vfsd) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) delete(parent.children, name) ds = appendDentry(ds, child) + cw = &child.watches } + vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name) parent.dirents = nil return nil } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// isOverlayXattr returns whether the given extended attribute configures the +// overlay. +func isOverlayXattr(name string) bool { + return strings.HasPrefix(name, _OVL_XATTR_PREFIX) +} + +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err } - // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr, - // but not any other xattr syscalls. For now we just reject all of them. - return nil, syserror.ENOTSUP + + return fs.listXattr(ctx, d, size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +func (fs *filesystem) listXattr(ctx context.Context, d *dentry, size uint64) ([]string, error) { + vfsObj := d.fs.vfsfs.VirtualFilesystem() + top := d.topLayer() + names, err := vfsObj.ListXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, size) + if err != nil { + return nil, err + } + + // Filter out all overlay attributes. + n := 0 + for _, name := range names { + if !isOverlayXattr(name) { + names[n] = name + n++ + } + } + return names[:n], err +} + +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err } - return "", syserror.ENOTSUP + + return fs.getXattr(ctx, d, rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +func (fs *filesystem) getXattr(ctx context.Context, d *dentry, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { + return "", err + } + + // Return EOPNOTSUPP when fetching an overlay attribute. + // See fs/overlayfs/super.c:ovl_own_xattr_get(). + if isOverlayXattr(opts.Name) { + return "", syserror.EOPNOTSUPP + } + + // Analogous to fs/overlayfs/super.c:ovl_other_xattr_get(). + vfsObj := d.fs.vfsfs.VirtualFilesystem() + top := d.topLayer() + return vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, opts) +} + +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + return err + } + + err = fs.setXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), &opts) + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) if err != nil { return err } - return syserror.ENOTSUP + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent) + return nil +} + +// Precondition: fs.renameMu must be locked. +func (fs *filesystem) setXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { + return err + } + + // Return EOPNOTSUPP when setting an overlay attribute. + // See fs/overlayfs/super.c:ovl_own_xattr_set(). + if isOverlayXattr(opts.Name) { + return syserror.EOPNOTSUPP + } + + // Analogous to fs/overlayfs/super.c:ovl_other_xattr_set(). + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := d.copyUpLocked(ctx); err != nil { + return err + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + return vfsObj.SetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, opts) } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + return err + } + + err = fs.removeXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), name) + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) if err != nil { return err } - return syserror.ENOTSUP + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent) + return nil +} + +// Precondition: fs.renameMu must be locked. +func (fs *filesystem) removeXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, name string) error { + if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { + return err + } + + // Like SetXattrAt, return EOPNOTSUPP when removing an overlay attribute. + // Linux passes the remove request to xattr_handler->set. + // See fs/xattr.c:vfs_removexattr(). + if isOverlayXattr(name) { + return syserror.EOPNOTSUPP + } + + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := d.copyUpLocked(ctx); err != nil { + return err + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + return vfsObj.RemoveXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index a3c1f7a8d..853aee951 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -38,6 +39,7 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { }) } +// +stateify savable type nonDirectoryFD struct { fileDescription @@ -46,7 +48,7 @@ type nonDirectoryFD struct { // fileDescription.dentry().upperVD. cachedFlags is the last known value of // cachedFD.StatusFlags(). copiedUp, cachedFD, and cachedFlags are // protected by mu. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` copiedUp bool cachedFD *vfs.FileDescription cachedFlags uint32 @@ -81,11 +83,11 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip oldOff, oldOffErr := fd.cachedFD.Seek(ctx, 0, linux.SEEK_CUR) if oldOffErr == nil { if _, err := upperFD.Seek(ctx, oldOff, linux.SEEK_SET); err != nil { - upperFD.DecRef() + upperFD.DecRef(ctx) return nil, err } } - fd.cachedFD.DecRef() + fd.cachedFD.DecRef(ctx) fd.copiedUp = true fd.cachedFD = upperFD fd.cachedFlags = statusFlags @@ -99,8 +101,8 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *nonDirectoryFD) Release() { - fd.cachedFD.DecRef() +func (fd *nonDirectoryFD) Release(ctx context.Context) { + fd.cachedFD.DecRef(ctx) fd.cachedFD = nil } @@ -121,7 +123,6 @@ func (fd *nonDirectoryFD) OnClose(ctx context.Context) error { fd.cachedFlags = statusFlags } wrappedFD := fd.cachedFD - defer wrappedFD.IncRef() fd.mu.Unlock() return wrappedFD.OnClose(ctx) } @@ -138,7 +139,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux Mask: layerMask, Sync: opts.Sync, }) - wrappedFD.DecRef() + wrappedFD.DecRef(ctx) if err != nil { return linux.Statx{}, err } @@ -147,11 +148,21 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux return stat, nil } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + return err + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Allocate(ctx, mode, offset, length) +} + // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := fd.vfsfd.Mount() @@ -173,10 +184,13 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) return err } d.updateAfterSetStatLocked(&opts) + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) + } return nil } -// StatFS implements vfs.FileDesciptionImpl.StatFS. +// StatFS implements vfs.FileDescriptionImpl.StatFS. func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.filesystem().statFS(ctx) } @@ -187,7 +201,7 @@ func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, off if err != nil { return 0, err } - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) return wrappedFD.PRead(ctx, dst, offset, opts) } @@ -209,7 +223,7 @@ func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, of if err != nil { return 0, err } - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) return wrappedFD.PWrite(ctx, src, offset, opts) } @@ -250,17 +264,112 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error { return err } wrappedFD.IncRef() - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) fd.mu.Unlock() return wrappedFD.Sync(ctx) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - wrappedFD, err := fd.getCurrentFD(ctx) + if err := fd.ensureMappable(ctx, opts); err != nil { + return err + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd.dentry(), opts) +} + +// ensureMappable ensures that fd.dentry().wrappedMappable is not nil. +func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { + d := fd.dentry() + + // Fast path if we already have a Mappable for the current top layer. + if atomic.LoadUint32(&d.isMappable) != 0 { + return nil + } + + // Only permit mmap of regular files, since other file types may have + // unpredictable behavior when mmapped (e.g. /dev/zero). + if atomic.LoadUint32(&d.mode)&linux.S_IFMT != linux.S_IFREG { + return syserror.ENODEV + } + + // Get a Mappable for the current top layer. + fd.mu.Lock() + defer fd.mu.Unlock() + d.copyMu.RLock() + defer d.copyMu.RUnlock() + if atomic.LoadUint32(&d.isMappable) != 0 { + return nil + } + wrappedFD, err := fd.currentFDLocked(ctx) if err != nil { return err } - defer wrappedFD.DecRef() - return wrappedFD.ConfigureMMap(ctx, opts) + if err := wrappedFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + // Use this Mappable for all mappings of this layer (unless we raced with + // another call to ensureMappable). + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + d.dataMu.Lock() + defer d.dataMu.Unlock() + if d.wrappedMappable == nil { + d.wrappedMappable = opts.Mappable + atomic.StoreUint32(&d.isMappable, 1) + } + return nil +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil { + return err + } + if !d.isCopiedUp() { + d.lowerMappings.AddMapping(ms, ar, offset, writable) + } + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable) + if !d.isCopiedUp() { + d.lowerMappings.RemoveMapping(ms, ar, offset, writable) + } +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil { + return err + } + if !d.isCopiedUp() { + d.lowerMappings.AddMapping(ms, dstAR, offset, writable) + } + return nil +} + +// Translate implements memmap.Mappable.Translate. +func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + d.dataMu.RLock() + defer d.dataMu.RUnlock() + return d.wrappedMappable.Translate(ctx, required, optional, at) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (d *dentry) InvalidateUnsavable(ctx context.Context) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + return d.wrappedMappable.InvalidateUnsavable(ctx) } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e720d4825..dfbccd05f 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,10 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// *** "memmap.Mappable locks" below this point +// dentry.mapsMu +// *** "memmap.Mappable locks taken by Translate" below this point +// dentry.dataMu // // Locking dentry.dirMu in multiple dentries requires that parent dentries are // locked before child dentries, and that filesystem.renameMu is locked to @@ -37,6 +41,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -46,6 +51,8 @@ import ( const Name = "overlay" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // Name implements vfs.FilesystemType.Name. @@ -55,6 +62,8 @@ func (FilesystemType) Name() string { // FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to // FilesystemType.GetFilesystem. +// +// +stateify savable type FilesystemOptions struct { // Callers passing FilesystemOptions to // overlay.FilesystemType.GetFilesystem() are responsible for ensuring that @@ -71,6 +80,8 @@ type FilesystemOptions struct { } // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -93,7 +104,7 @@ type filesystem struct { // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different // dentries. - renameMu sync.RWMutex + renameMu sync.RWMutex `state:"nosave"` // lastDirIno is the last inode number assigned to a directory. lastDirIno // is accessed using atomic memory operations. @@ -106,16 +117,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsoptsRaw := opts.InternalData fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) if fsoptsRaw != nil && !haveFSOpts { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) + ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } if haveFSOpts { if len(fsopts.LowerRoots) == 0 { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") return nil, nil, syserror.EINVAL } if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") return nil, nil, syserror.EINVAL } // We don't enforce a maximum number of lower layers when not @@ -123,7 +134,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // filesystem with any number of lower layers. } else { vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef() + defer vfsroot.DecRef(ctx) upperPathname, ok := mopts["upperdir"] if ok { delete(mopts, "upperdir") @@ -132,7 +143,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt delete(mopts, "workdir") upperPath := fspath.Parse(upperPathname) if !upperPath.Absolute { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -144,38 +155,38 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt CheckSearchable: true, }) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) return nil, nil, err } - defer upperRoot.DecRef() + defer upperRoot.DecRef(ctx) privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) return nil, nil, err } - defer privateUpperRoot.DecRef() + defer privateUpperRoot.DecRef(ctx) fsopts.UpperRoot = privateUpperRoot } lowerPathnamesStr, ok := mopts["lowerdir"] if !ok { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") return nil, nil, syserror.EINVAL } if len(lowerPathnames) > maxLowerLayers { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) return nil, nil, syserror.EINVAL } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) + ctx.Infof("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) return nil, nil, syserror.EINVAL } lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -187,21 +198,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt CheckSearchable: true, }) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef() + defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer privateLowerRoot.DecRef() + defer privateLowerRoot.DecRef(ctx) fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } if len(mopts) != 0 { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) + ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } @@ -264,19 +275,19 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Mask: rootStatMask, }) if err != nil { - root.destroyLocked() - fs.vfsfs.DecRef() + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) return nil, nil, err } if rootStat.Mask&rootStatMask != rootStatMask { - root.destroyLocked() - fs.vfsfs.DecRef() + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EREMOTE } if isWhiteout(&rootStat) { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") - root.destroyLocked() - fs.vfsfs.DecRef() + ctx.Infof("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL } root.mode = uint32(rootStat.Mode) @@ -315,21 +326,25 @@ func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forc if err != nil { return vfs.VirtualDentry{}, err } - return vfs.MakeVirtualDentry(newmnt, vd.Dentry()), nil + // Take a reference on the dentry which will be owned by the returned + // VirtualDentry. + d := vd.Dentry() + d.IncRef() + return vfs.MakeVirtualDentry(newmnt, d), nil } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { vfsObj := fs.vfsfs.VirtualFilesystem() vfsObj.PutAnonBlockDevMinor(fs.dirDevMinor) for _, lowerDevMinor := range fs.lowerDevMinors { vfsObj.PutAnonBlockDevMinor(lowerDevMinor) } if fs.opts.UpperRoot.Ok() { - fs.opts.UpperRoot.DecRef() + fs.opts.UpperRoot.DecRef(ctx) } for _, lowerRoot := range fs.opts.LowerRoots { - lowerRoot.DecRef() + lowerRoot.DecRef(ctx) } } @@ -358,6 +373,8 @@ func (fs *filesystem) newDirIno() uint64 { } // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -390,7 +407,7 @@ type dentry struct { // and dirents (if not nil) is a cache of dirents as returned by // directoryFDs representing this directory. children is protected by // dirMu. - dirMu sync.Mutex + dirMu sync.Mutex `state:"nosave"` children map[string]*dentry dirents []vfs.Dirent @@ -400,7 +417,7 @@ type dentry struct { // If !upperVD.Ok(), it can transition to a valid vfs.VirtualDentry (i.e. // be copied up) with copyMu locked for writing; otherwise, it is // immutable. lowerVDs is always immutable. - copyMu sync.RWMutex + copyMu sync.RWMutex `state:"nosave"` upperVD vfs.VirtualDentry lowerVDs []vfs.VirtualDentry @@ -415,7 +432,43 @@ type dentry struct { devMinor uint32 ino uint64 + // If this dentry represents a regular file, then: + // + // - mapsMu is used to synchronize between copy-up and memmap.Mappable + // methods on dentry preceding mm.MemoryManager.activeMu in the lock order. + // + // - dataMu is used to synchronize between copy-up and + // dentry.(memmap.Mappable).Translate. + // + // - lowerMappings tracks memory mappings of the file. lowerMappings is + // used to invalidate mappings of the lower layer when the file is copied + // up to ensure that they remain coherent with subsequent writes to the + // file. (Note that, as of this writing, Linux overlayfs does not do this; + // this feature is a gVisor extension.) lowerMappings is protected by + // mapsMu. + // + // - If this dentry is copied-up, then wrappedMappable is the Mappable + // obtained from a call to the current top layer's + // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil + // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil. + // wrappedMappable is protected by mapsMu and dataMu. + // + // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is + // accessed using atomic memory operations. + mapsMu sync.Mutex + lowerMappings memmap.MappingSet + dataMu sync.RWMutex + wrappedMappable memmap.Mappable + isMappable uint32 + locks vfs.FileLocks + + // watches is the set of inotify watches on the file repesented by this dentry. + // + // Note that hard links to the same file will not share the same set of + // watches, due to the fact that we do not have inode structures in this + // overlay implementation. + watches vfs.Watches } // newDentry creates a new dentry. The dentry initially has no references; it @@ -452,10 +505,10 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { +func (d *dentry) DecRef(ctx context.Context) { if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { d.fs.renameMu.Lock() - d.checkDropLocked() + d.checkDropLocked(ctx) d.fs.renameMu.Unlock() } else if refs < 0 { panic("overlay.dentry.DecRef() called without holding a reference") @@ -466,7 +519,7 @@ func (d *dentry) DecRef() { // becomes deleted. // // Preconditions: d.fs.renameMu must be locked for writing. -func (d *dentry) checkDropLocked() { +func (d *dentry) checkDropLocked(ctx context.Context) { // Dentries with a positive reference count must be retained. (The only way // to obtain a reference on a dentry with zero references is via path // resolution, which requires renameMu, so if d.refs is zero then it will @@ -475,15 +528,25 @@ func (d *dentry) checkDropLocked() { if atomic.LoadInt64(&d.refs) != 0 { return } + + // Make sure that we do not lose watches on dentries that have not been + // deleted. Note that overlayfs never calls VFS.InvalidateDentry(), so + // d.vfsd.IsDead() indicates that d was deleted. + if !d.vfsd.IsDead() && d.watches.Size() > 0 { + return + } + // Refs is still zero; destroy it. - d.destroyLocked() + d.destroyLocked(ctx) return } // destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. -func (d *dentry) destroyLocked() { +// Preconditions: +// * d.fs.renameMu must be locked for writing. +// * d.refs == 0. +func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: // Mark the dentry destroyed. @@ -495,12 +558,14 @@ func (d *dentry) destroyLocked() { } if d.upperVD.Ok() { - d.upperVD.DecRef() + d.upperVD.DecRef(ctx) } for _, lowerVD := range d.lowerVDs { - lowerVD.DecRef() + lowerVD.DecRef(ctx) } + d.watches.HandleDeletion(ctx) + if d.parent != nil { d.parent.dirMu.Lock() if !d.vfsd.IsDead() { @@ -510,7 +575,7 @@ func (d *dentry) destroyLocked() { // Drop the reference held by d on its parent without recursively // locking d.fs.renameMu. if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked() + d.parent.checkDropLocked(ctx) } else if refs < 0 { panic("overlay.dentry.DecRef() called without holding a reference") } @@ -518,20 +583,37 @@ func (d *dentry) destroyLocked() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) { - // TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) InotifyWithParent(ctx context.Context, events uint32, cookie uint32, et vfs.EventType) { + if d.isDir() { + events |= linux.IN_ISDIR + } + + // overlayfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates + // that d was deleted. + deleted := d.vfsd.IsDead() + + d.fs.renameMu.RLock() + // The ordering below is important, Linux always notifies the parent first. + if d.parent != nil { + d.parent.watches.Notify(ctx, d.name, events, cookie, et, deleted) + } + d.watches.Notify(ctx, "", events, cookie, et, deleted) + d.fs.renameMu.RUnlock() } // Watches implements vfs.DentryImpl.Watches. func (d *dentry) Watches() *vfs.Watches { - // TODO(gvisor.dev/issue/1479): Implement inotify. - return nil + return &d.watches } // OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *dentry) OnZeroWatches() {} +func (d *dentry) OnZeroWatches(ctx context.Context) { + if atomic.LoadInt64(&d.refs) == 0 { + d.fs.renameMu.Lock() + d.checkDropLocked(ctx) + d.fs.renameMu.Unlock() + } +} // iterLayers invokes yield on each layer comprising d, from top to bottom. If // any call to yield returns false, iterLayer stops iteration. @@ -564,6 +646,16 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + kuid := auth.KUID(atomic.LoadUint32(&d.uid)) + kgid := auth.KGID(atomic.LoadUint32(&d.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err + } + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) +} + // statInternalMask is the set of stat fields that is set by // dentry.statInternalTo(). const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -602,6 +694,8 @@ func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) { // fileDescription is embedded by overlay implementations of // vfs.FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -616,6 +710,48 @@ func (fd *fileDescription) dentry() *dentry { return fd.vfsfd.Dentry().Impl().(*dentry) } +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.filesystem().listXattr(ctx, fd.dentry(), size) +} + +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.filesystem().getXattr(ctx, fd.dentry(), auth.CredentialsFromContext(ctx), &opts) +} + +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { + fs := fd.filesystem() + d := fd.dentry() + + fs.renameMu.RLock() + err := fs.setXattrLocked(ctx, d, fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), &opts) + fs.renameMu.RUnlock() + if err != nil { + return err + } + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil +} + +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { + fs := fd.filesystem() + d := fd.dentry() + + fs.renameMu.RLock() + err := fs.removeXattrLocked(ctx, d, fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), name) + fs.renameMu.RUnlock() + if err != nil { + return err + } + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index dd7eaf4a8..4e2da4810 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type filesystemType struct{} // Name implements vfs.FilesystemType.Name. @@ -43,6 +44,7 @@ func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile panic("pipefs.filesystemType.GetFilesystem should never be called") } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -63,9 +65,9 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // PrependPath implements vfs.FilesystemImpl.PrependPath. @@ -76,6 +78,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink @@ -115,7 +119,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode.Stat. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) return linux.Statx{ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, @@ -143,12 +147,14 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth. return syserror.EPERM } -// TODO(gvisor.dev/issue/1193): kernfs does not provide a way to implement -// statfs, from which we should indicate PIPEFS_MAGIC. - // Open implements kernfs.Inode.Open. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks) +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + return i.pipe.Open(ctx, rp.Mount(), d.VFSDentry(), opts.Flags, &i.locks) +} + +// StatFS implements kernfs.Inode.StatFS. +func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.PIPEFS_MAGIC), nil } // NewConnectedPipeFDs returns a pair of FileDescriptions representing the read @@ -160,6 +166,6 @@ func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vf inode := newInode(ctx, fs) var d kernfs.Dentry d.Init(inode) - defer d.DecRef() + defer d.DecRef(ctx) return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags) } diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 6014138ff..2e086e34c 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,18 +1,79 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "fd_dir_inode_refs", + out = "fd_dir_inode_refs.go", + package = "proc", + prefix = "fdDirInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "fdDirInode", + }, +) + +go_template_instance( + name = "fd_info_dir_inode_refs", + out = "fd_info_dir_inode_refs.go", + package = "proc", + prefix = "fdInfoDirInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "fdInfoDirInode", + }, +) + +go_template_instance( + name = "subtasks_inode_refs", + out = "subtasks_inode_refs.go", + package = "proc", + prefix = "subtasksInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "subtasksInode", + }, +) + +go_template_instance( + name = "task_inode_refs", + out = "task_inode_refs.go", + package = "proc", + prefix = "taskInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "taskInode", + }, +) + +go_template_instance( + name = "tasks_inode_refs", + out = "tasks_inode_refs.go", + package = "proc", + prefix = "tasksInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "tasksInode", + }, +) + go_library( name = "proc", srcs = [ + "fd_dir_inode_refs.go", + "fd_info_dir_inode_refs.go", "filesystem.go", "subtasks.go", + "subtasks_inode_refs.go", "task.go", "task_fds.go", "task_files.go", + "task_inode_refs.go", "task_net.go", "tasks.go", "tasks_files.go", + "tasks_inode_refs.go", "tasks_sys.go", ], visibility = ["//pkg/sentry:internal"], @@ -36,8 +97,10 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 609210253..05d7948ea 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -41,6 +41,7 @@ func (FilesystemType) Name() string { return Name } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -77,13 +78,15 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // dynamicInode is an overfitted interface for common Inodes with // dynamicByteSource types used in procfs. +// +// +stateify savable type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource @@ -99,6 +102,7 @@ func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux. return d } +// +stateify savable type staticFile struct { kernfs.DynamicBytesFile vfs.StaticData @@ -110,8 +114,24 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } +func newStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry { + return kernfs.NewStaticDir(creds, devMajor, devMinor, ino, perm, children, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) +} + // InternalData contains internal data passed in to the procfs mount via // vfs.GetFilesystemOptions.InternalData. +// +// +stateify savable type InternalData struct { Cgroups map[string]string } + +// +stateify savable +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.PROC_SUPER_MAGIC), nil +} diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index 36a89540c..47ecd941c 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -31,11 +31,13 @@ import ( // // +stateify savable type subtasksInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid + subtasksInodeRefs locks vfs.FileLocks @@ -57,6 +59,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, // Note: credentials are overridden by taskOwnedInode. subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + subInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: subInode, owner: task} dentry := &kernfs.Dentry{} @@ -65,8 +68,8 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, return dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { tid, err := strconv.ParseUint(name, 10, 32) if err != nil { return nil, syserror.ENOENT @@ -79,12 +82,10 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, e if subTask.ThreadGroup() != i.task.ThreadGroup() { return nil, syserror.ENOENT } - - subTaskDentry := i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers) - return subTaskDentry.VFSDentry(), nil + return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers), nil } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { @@ -115,6 +116,7 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb return offset, nil } +// +stateify savable type subtasksFD struct { kernfs.GenericDirectoryFD @@ -128,7 +130,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac return fd.GenericDirectoryFD.IterDirents(ctx, cb) } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { if fd.task.ExitState() >= kernel.TaskExitZombie { return 0, syserror.ENOENT @@ -152,21 +154,23 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro return fd.GenericDirectoryFD.SetStat(ctx, opts) } -// Open implements kernfs.Inode. -func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +// Open implements kernfs.Inode.Open. +func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &subtasksFD{task: i.task} - if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil { + if err := fd.Init(&i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }); err != nil { return nil, err } - if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// Stat implements kernfs.Inode. -func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +// Stat implements kernfs.Inode.Stat. +func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } @@ -176,7 +180,12 @@ func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux. return stat, nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } + +// DecRef implements kernfs.Inode.DecRef. +func (i *subtasksInode) DecRef(context.Context) { + i.subtasksInodeRefs.DecRef(i.Destroy) +} diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 8bb2b0ce1..a7cd6f57e 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -32,11 +32,13 @@ import ( // // +stateify savable type taskInode struct { - kernfs.InodeNotSymlink + implStatFS + kernfs.InodeAttrs kernfs.InodeDirectoryNoNewChildren kernfs.InodeNoDynamicLookup - kernfs.InodeAttrs + kernfs.InodeNotSymlink kernfs.OrderedChildren + taskInodeRefs locks vfs.FileLocks @@ -51,6 +53,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}), "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), "comm": fs.newComm(task, fs.NextIno(), 0444), + "cwd": fs.newCwdSymlink(task, fs.NextIno()), "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), "exe": fs.newExeSymlink(task, fs.NextIno()), "fd": fs.newFDDirInode(task), @@ -84,6 +87,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: taskInode, owner: task} dentry := &kernfs.Dentry{} @@ -103,22 +107,31 @@ func (i *taskInode) Valid(ctx context.Context) bool { return i.task.ExitState() != kernel.TaskExitDead } -// Open implements kernfs.Inode. -func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) +// Open implements kernfs.Inode.Open. +func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } +// DecRef implements kernfs.Inode.DecRef. +func (i *taskInode) DecRef(context.Context) { + i.taskInodeRefs.DecRef(i.Destroy) +} + // taskOwnedInode implements kernfs.Inode and overrides inode owner with task // effective user and group. +// +// +stateify savable type taskOwnedInode struct { kernfs.Inode @@ -142,7 +155,10 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. dir := &kernfs.StaticDirectory{} // Note: credentials are overridden by taskOwnedInode. - dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) + dir.EnableLeakCheck() inode := &taskOwnedInode{Inode: dir, owner: task} d := &kernfs.Dentry{} @@ -155,9 +171,9 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. return d } -// Stat implements kernfs.Inode. -func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.Inode.Stat(fs, opts) +// Stat implements kernfs.Inode.Stat. +func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.Inode.Stat(ctx, fs, opts) if err != nil { return linux.Statx{}, err } @@ -173,7 +189,7 @@ func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.S return stat, nil } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { mode := i.Mode() uid, gid := i.getOwner(mode) diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index fea29e5f0..0866cea2b 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -43,15 +42,16 @@ func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) return file, flags } -func taskFDExists(t *kernel.Task, fd int32) bool { +func taskFDExists(ctx context.Context, t *kernel.Task, fd int32) bool { file, _ := getTaskFD(t, fd) if file == nil { return false } - file.DecRef() + file.DecRef(ctx) return true } +// +stateify savable type fdDir struct { locks vfs.FileLocks @@ -63,12 +63,12 @@ type fdDir struct { produceSymlink bool } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.GetFDs() + fds = fdTable.GetFDs(ctx) } }) @@ -87,26 +87,33 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off Name: strconv.FormatUint(uint64(fd), 10), Type: typ, Ino: i.fs.NextIno(), - NextOff: offset + 1, + NextOff: int64(fd) + 3, } if err := cb.Handle(dirent); err != nil { - return offset, err + // Getdents should iterate correctly despite mutation + // of fds, so we return the next fd to serialize plus + // 2 (which accounts for the "." and ".." tracked by + // kernfs) as the offset. + return int64(fd) + 2, err } - offset++ } - return offset, nil + // We serialized them all. Next offset should be higher than last + // serialized fd. + return int64(fds[len(fds)-1]) + 3, nil } // fdDirInode represents the inode for /proc/[pid]/fd directory. // // +stateify savable type fdDirInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + fdDir + fdDirInodeRefs + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid - fdDir } var _ kernfs.Inode = (*fdDirInode)(nil) @@ -120,6 +127,7 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry { }, } inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.EnableLeakCheck() dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -128,30 +136,31 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry { return dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *fdDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { return nil, syserror.ENOENT } fd := int32(fdInt) - if !taskFDExists(i.task, fd) { + if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } - taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()) - return taskDentry.VFSDentry(), nil + return i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()), nil } -// Open implements kernfs.Inode. -func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) +// Open implements kernfs.Inode.Open. +func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. // // This is to match Linux, which uses a special permission handler to guarantee // that a process can still access /proc/self/fd after it has executed @@ -173,10 +182,16 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia return err } +// DecRef implements kernfs.Inode.DecRef. +func (i *fdDirInode) DecRef(context.Context) { + i.fdDirInodeRefs.DecRef(i.Destroy) +} + // fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file. // // +stateify savable type fdSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -199,14 +214,14 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *ker return d } -func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { +func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { file, _ := getTaskFD(s.task, s.fd) if file == nil { return "", syserror.ENOENT } - defer file.DecRef() + defer file.DecRef(ctx) root := vfs.RootFromContext(ctx) - defer root.DecRef() + defer root.DecRef(ctx) return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } @@ -215,7 +230,7 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen if file == nil { return vfs.VirtualDentry{}, "", syserror.ENOENT } - defer file.DecRef() + defer file.DecRef(ctx) vd := file.VirtualDentry() vd.IncRef() return vd, "", nil @@ -225,12 +240,14 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen // // +stateify savable type fdInfoDirInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + fdDir + fdInfoDirInodeRefs + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid - fdDir } var _ kernfs.Inode = (*fdInfoDirInode)(nil) @@ -243,6 +260,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry { }, } inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.EnableLeakCheck() dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -251,39 +269,44 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry { return dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { return nil, syserror.ENOENT } fd := int32(fdInt) - if !taskFDExists(i.task, fd) { + if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } data := &fdInfoData{ task: i.task, fd: fd, } - dentry := i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data) - return dentry.VFSDentry(), nil + return i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data), nil } -// Open implements kernfs.Inode. -func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) +// Open implements kernfs.Inode.Open. +func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +// DecRef implements kernfs.Inode.DecRef. +func (i *fdInfoDirInode) DecRef(context.Context) { + i.fdInfoDirInodeRefs.DecRef(i.Destroy) +} + // fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd]. // // +stateify savable type fdInfoData struct { kernfs.DynamicBytesFile - refs.AtomicRefCount task *kernel.Task fd int32 @@ -297,7 +320,7 @@ func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { if file == nil { return syserror.ENOENT } - defer file.DecRef() + defer file.DecRef(ctx) // TODO(b/121266871): Include pos, locks, and other data. For now we only // have flags. // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 9af43b859..3fbf081a6 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -543,7 +543,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { var vss, rss, data uint64 s.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() @@ -648,6 +648,7 @@ func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset // // +stateify savable type exeSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -666,20 +667,24 @@ func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentr return d } -// Readlink implements kernfs.Inode. -func (s *exeSymlink) Readlink(ctx context.Context) (string, error) { - if !kernel.ContextCanTrace(ctx, s.task, false) { - return "", syserror.EACCES - } - - // Pull out the executable for /proc/[pid]/exe. - exec, err := s.executable() +// Readlink implements kernfs.Inode.Readlink. +func (s *exeSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { + exec, _, err := s.Getlink(ctx, nil) if err != nil { return "", err } - defer exec.DecRef() + defer exec.DecRef(ctx) + + root := vfs.RootFromContext(ctx) + if !root.Ok() { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) - return exec.PathnameWithDeleted(ctx), nil + vfsObj := exec.Mount().Filesystem().VirtualFilesystem() + name, _ := vfsObj.PathnameWithDeleted(ctx, root, exec) + return name, nil } // Getlink implements kernfs.Inode.Getlink. @@ -687,23 +692,12 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent if !kernel.ContextCanTrace(ctx, s.task, false) { return vfs.VirtualDentry{}, "", syserror.EACCES } - - exec, err := s.executable() - if err != nil { - return vfs.VirtualDentry{}, "", err - } - defer exec.DecRef() - - vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() - vd.IncRef() - return vd, "", nil -} - -func (s *exeSymlink) executable() (file fsbridge.File, err error) { if err := checkTaskState(s.task); err != nil { - return nil, err + return vfs.VirtualDentry{}, "", err } + var err error + var exec fsbridge.File s.task.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { @@ -714,12 +708,78 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) { // The MemoryManager may be destroyed, in which case // MemoryManager.destroy will simply set the executable to nil // (with locks held). - file = mm.Executable() - if file == nil { + exec = mm.Executable() + if exec == nil { err = syserror.ESRCH } }) - return + if err != nil { + return vfs.VirtualDentry{}, "", err + } + defer exec.DecRef(ctx) + + vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() + vd.IncRef() + return vd, "", nil +} + +// cwdSymlink is an symlink for the /proc/[pid]/cwd file. +// +// +stateify savable +type cwdSymlink struct { + implStatFS + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeSymlink + + task *kernel.Task +} + +var _ kernfs.Inode = (*cwdSymlink)(nil) + +func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry { + inode := &cwdSymlink{task: task} + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +// Readlink implements kernfs.Inode.Readlink. +func (s *cwdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { + cwd, _, err := s.Getlink(ctx, nil) + if err != nil { + return "", err + } + defer cwd.DecRef(ctx) + + root := vfs.RootFromContext(ctx) + if !root.Ok() { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + vfsObj := cwd.Mount().Filesystem().VirtualFilesystem() + name, _ := vfsObj.PathnameWithDeleted(ctx, root, cwd) + return name, nil +} + +// Getlink implements kernfs.Inode.Getlink. +func (s *cwdSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { + if !kernel.ContextCanTrace(ctx, s.task, false) { + return vfs.VirtualDentry{}, "", syserror.EACCES + } + if err := checkTaskState(s.task); err != nil { + return vfs.VirtualDentry{}, "", err + } + cwd := s.task.FSContext().WorkingDirectoryVFS2() + if !cwd.Ok() { + // It could have raced with process deletion. + return vfs.VirtualDentry{}, "", syserror.ESRCH + } + return cwd, "", nil } // mountInfoData is used to implement /proc/[pid]/mountinfo. @@ -748,7 +808,7 @@ func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Root has been destroyed. Don't try to read mounts. return nil } - defer rootDir.DecRef() + defer rootDir.DecRef(ctx) i.task.Kernel().VFS().GenerateProcMountInfo(ctx, rootDir, buf) return nil } @@ -779,11 +839,12 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Root has been destroyed. Don't try to read mounts. return nil } - defer rootDir.DecRef() + defer rootDir.DecRef(ctx) i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf) return nil } +// +stateify savable type namespaceSymlink struct { kernfs.StaticSymlink @@ -806,15 +867,15 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri return d } -// Readlink implements Inode. -func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) { +// Readlink implements kernfs.Inode.Readlink. +func (s *namespaceSymlink) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { if err := checkTaskState(s.task); err != nil { return "", err } - return s.StaticSymlink.Readlink(ctx) + return s.StaticSymlink.Readlink(ctx, mnt) } -// Getlink implements Inode.Getlink. +// Getlink implements kernfs.Inode.Getlink. func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { if err := checkTaskState(s.task); err != nil { return vfs.VirtualDentry{}, "", err @@ -825,13 +886,16 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir dentry.Init(&namespaceInode{}) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) vd.IncRef() - dentry.DecRef() + dentry.DecRef(ctx) return vd, "", nil } // namespaceInode is a synthetic inode created to represent a namespace in // /proc/[pid]/ns/*. +// +// +stateify savable type namespaceInode struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeNotDirectory @@ -850,12 +914,12 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32 i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } -// Open implements Inode.Open. -func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +// Open implements kernfs.Inode.Open. +func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &namespaceFD{inode: i} i.IncRef() fd.LockFD.Init(&i.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil @@ -863,6 +927,8 @@ func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd * // namespace FD is a synthetic file that represents a namespace in // /proc/[pid]/ns/*. +// +// +stateify savable type namespaceFD struct { vfs.FileDescriptionDefaultImpl vfs.LockFD @@ -873,22 +939,22 @@ type namespaceFD struct { var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil) -// Stat implements FileDescriptionImpl. +// Stat implements vfs.FileDescriptionImpl.Stat. func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(vfs, opts) + return fd.inode.Stat(ctx, vfs, opts) } -// SetStat implements FileDescriptionImpl. +// SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() creds := auth.CredentialsFromContext(ctx) return fd.inode.SetStat(ctx, vfs, creds, opts) } -// Release implements FileDescriptionImpl. -func (fd *namespaceFD) Release() { - fd.inode.DecRef() +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *namespaceFD) Release(ctx context.Context) { + fd.inode.DecRef(ctx) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 6bde27376..e7f748655 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -212,7 +212,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { continue } if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX { - s.DecRef() + s.DecRef(ctx) // Not a unix socket. continue } @@ -262,7 +262,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { // For now, we always redact this pointer. fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d", (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct. - s.Refs()-1, // RefCount, don't count our own ref. + s.ReadRefs()-1, // RefCount, don't count our own ref. 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. @@ -281,7 +281,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { } fmt.Fprintf(buf, "\n") - s.DecRef() + s.DecRef(ctx) } return nil } @@ -359,7 +359,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s)) } if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { - s.DecRef() + s.DecRef(ctx) // Not tcp4 sockets. continue } @@ -430,7 +430,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, // Field: refcount. Don't count the ref we obtain while deferencing // the weakref to this socket. - fmt.Fprintf(buf, "%d ", s.Refs()-1) + fmt.Fprintf(buf, "%d ", s.ReadRefs()-1) // Field: Socket struct address. Redacted due to the same reason as // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. @@ -455,7 +455,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, fmt.Fprintf(buf, "\n") - s.DecRef() + s.DecRef(ctx) } return nil @@ -524,7 +524,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s)) } if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { - s.DecRef() + s.DecRef(ctx) // Not udp4 socket. continue } @@ -589,7 +589,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Field: ref; reference count on the socket inode. Don't count the ref // we obtain while deferencing the weakref to this socket. - fmt.Fprintf(buf, "%d ", s.Refs()-1) + fmt.Fprintf(buf, "%d ", s.ReadRefs()-1) // Field: Socket struct address. Redacted due to the same reason as // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. @@ -600,7 +600,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "\n") - s.DecRef() + s.DecRef(ctx) } return nil } @@ -616,6 +616,7 @@ type netSnmpData struct { var _ dynamicInode = (*netSnmpData)(nil) +// +stateify savable type snmpLine struct { prefix string header string @@ -660,7 +661,7 @@ func sprintSlice(s []uint64) string { return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. } -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { types := []interface{}{ &inet.StatSNMPIP{}, @@ -709,7 +710,7 @@ type netRouteData struct { var _ dynamicInode = (*netRouteData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. // See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT") @@ -773,7 +774,7 @@ type netStatData struct { var _ dynamicInode = (*netStatData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. // See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " + diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 2f214d0c2..d8f5dd509 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -37,11 +37,13 @@ const ( // // +stateify savable type tasksInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid + tasksInodeRefs locks vfs.FileLocks @@ -50,8 +52,8 @@ type tasksInode struct { // '/proc/self' and '/proc/thread-self' have custom directory offsets in // Linux. So handle them outside of OrderedChildren. - selfSymlink *vfs.Dentry - threadSelfSymlink *vfs.Dentry + selfSymlink *kernfs.Dentry + threadSelfSymlink *kernfs.Dentry // cgroupControllers is a map of controller name to directory in the // cgroup hierarchy. These controllers are immutable and will be listed @@ -79,11 +81,12 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace inode := &tasksInode{ pidns: pidns, fs: fs, - selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(), - threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(), + selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns), + threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns), cgroupControllers: cgroupControllers, } inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.EnableLeakCheck() dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -95,8 +98,8 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace return inode, dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { // Try to lookup a corresponding task. tid, err := strconv.ParseUint(name, 10, 64) if err != nil { @@ -115,11 +118,10 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return nil, syserror.ENOENT } - taskDentry := i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers) - return taskDentry.VFSDentry(), nil + return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers), nil } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 @@ -197,17 +199,19 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback return maxTaskID, nil } -// Open implements kernfs.Inode. -func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) +// Open implements kernfs.Inode.Open. +func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } @@ -224,9 +228,16 @@ func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Sta return stat, nil } +// DecRef implements kernfs.Inode.DecRef. +func (i *tasksInode) DecRef(context.Context) { + i.tasksInodeRefs.DecRef(i.Destroy) +} + // staticFileSetStat implements a special static file that allows inode // attributes to be set. This is to support /proc files that are readonly, but // allow attributes to be set. +// +// +stateify savable type staticFileSetStat struct { dynamicBytesFileSetAttr vfs.StaticData diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 7d8983aa5..f268c59b0 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -31,7 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type selfSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -50,7 +52,7 @@ func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns return d } -func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { +func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? @@ -63,17 +65,19 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { return strconv.FormatUint(uint64(tgid), 10), nil } -func (s *selfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { - target, err := s.Readlink(ctx) +func (s *selfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx, mnt) return vfs.VirtualDentry{}, target, err } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } +// +stateify savable type threadSelfSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -92,7 +96,7 @@ func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, return d } -func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { +func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? @@ -106,12 +110,12 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { return fmt.Sprintf("%d/task/%d", tgid, tid), nil } -func (s *threadSelfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { - target, err := s.Readlink(ctx) +func (s *threadSelfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx, mnt) return vfs.VirtualDentry{}, target, err } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } @@ -119,16 +123,20 @@ func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Creden // dynamicBytesFileSetAttr implements a special file that allows inode // attributes to be set. This is to support /proc files that are readonly, but // allow attributes to be set. +// +// +stateify savable type dynamicBytesFileSetAttr struct { kernfs.DynamicBytesFile } -// SetStat implements Inode.SetStat. +// SetStat implements kernfs.Inode.SetStat. func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts) } // cpuStats contains the breakdown of CPU time for /proc/stat. +// +// +stateify savable type cpuStats struct { // user is time spent in userspace tasks with non-positive niceness. user uint64 diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 6dac2afa4..3312b0418 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -25,20 +25,30 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable +type tcpMemDir int + +const ( + tcpRMem tcpMemDir = iota + tcpWMem +) + // newSysDir returns the dentry corresponding to /proc/sys directory. func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry { - return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ - "kernel": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "kernel": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}), "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)), "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)), "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)), }), - "vm": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "vm": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}), "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")), }), @@ -54,8 +64,12 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ - "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ - "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), + "ipv4": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -98,7 +112,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")), "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), }), - "core": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "core": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")), "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")), "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")), @@ -112,7 +126,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke } } - return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents) + return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for @@ -163,7 +177,7 @@ type tcpSackData struct { var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { if d.enabled == nil { sack, err := d.stack.TCPSACKEnabled() @@ -180,10 +194,11 @@ func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Tough luck. val = "1\n" } - buf.WriteString(val) - return nil + _, err := buf.WriteString(val) + return err } +// Write implements vfs.WritableDynamicBytesSource.Write. func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. @@ -199,7 +214,7 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) if err != nil { - return n, err + return 0, err } if d.enabled == nil { d.enabled = new(bool) @@ -207,3 +222,198 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset *d.enabled = v != 0 return n, d.stack.SetTCPSACKEnabled(*d.enabled) } + +// tcpRecoveryData implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/tcp_recovery. +// +// +stateify savable +type tcpRecoveryData struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` +} + +var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error { + recovery, err := d.stack.TCPRecovery() + if err != nil { + return err + } + + _, err = buf.WriteString(fmt.Sprintf("%d\n", recovery)) + return err +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit the amount of memory allocated. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + if err := d.stack.SetTCPRecovery(inet.TCPLossRecovery(v)); err != nil { + return 0, err + } + return n, nil +} + +// tcpMemData implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/tcp_rmem and /proc/sys/net/ipv4/tcp_wmem. +// +// +stateify savable +type tcpMemData struct { + kernfs.DynamicBytesFile + + dir tcpMemDir + stack inet.Stack `state:"wait"` + + // mu protects against concurrent reads/writes to FDs based on the dentry + // backing this byte source. + mu sync.Mutex `state:"nosave"` +} + +var _ vfs.WritableDynamicBytesSource = (*tcpMemData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error { + d.mu.Lock() + defer d.mu.Unlock() + + size, err := d.readSizeLocked() + if err != nil { + return err + } + _, err = buf.WriteString(fmt.Sprintf("%d\t%d\t%d\n", size.Min, size.Default, size.Max)) + return err +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + d.mu.Lock() + defer d.mu.Unlock() + + // Limit the amount of memory allocated. + src = src.TakeFirst(usermem.PageSize - 1) + size, err := d.readSizeLocked() + if err != nil { + return 0, err + } + buf := []int32{int32(size.Min), int32(size.Default), int32(size.Max)} + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, buf, src.Opts) + if err != nil { + return 0, err + } + newSize := inet.TCPBufferSize{ + Min: int(buf[0]), + Default: int(buf[1]), + Max: int(buf[2]), + } + if err := d.writeSizeLocked(newSize); err != nil { + return 0, err + } + return n, nil +} + +// Precondition: d.mu must be locked. +func (d *tcpMemData) readSizeLocked() (inet.TCPBufferSize, error) { + switch d.dir { + case tcpRMem: + return d.stack.TCPReceiveBufferSize() + case tcpWMem: + return d.stack.TCPSendBufferSize() + default: + panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir)) + } +} + +// Precondition: d.mu must be locked. +func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error { + switch d.dir { + case tcpRMem: + return d.stack.SetTCPReceiveBufferSize(size) + case tcpWMem: + return d.stack.SetTCPSendBufferSize(size) + default: + panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir)) + } +} + +// ipForwarding implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/ip_forwarding. +// +// +stateify savable +type ipForwarding struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` + enabled *bool +} + +var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error { + if ipf.enabled == nil { + enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber) + ipf.enabled = &enabled + } + + val := "0\n" + if *ipf.enabled { + // Technically, this is not quite compatible with Linux. Linux stores these + // as an integer, so if you write "2" into tcp_sack, you should get 2 back. + // Tough luck. + val = "1\n" + } + buf.WriteString(val) + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + if ipf.enabled == nil { + ipf.enabled = new(bool) + } + *ipf.enabled = v != 0 + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + return 0, err + } + return n, nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index be54897bb..6cee22823 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -20,8 +20,10 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/usermem" ) func newIPv6TestStack() *inet.TestStack { @@ -76,3 +78,72 @@ func TestIfinet6(t *testing.T) { t.Errorf("Got n.contents() = %v, want = %v", got, want) } } + +// TestIPForwarding tests the implementation of +// /proc/sys/net/ipv4/ip_forwarding +func TestConfigureIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + + file := &ipForwarding{stack: s, enabled: &c.initial} + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) + } + }) + } +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 19abb5034..6975af5a7 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -67,6 +67,7 @@ var ( taskStaticFiles = map[string]testutil.DirentType{ "auxv": linux.DT_REG, "cgroup": linux.DT_REG, + "cwd": linux.DT_LNK, "cmdline": linux.DT_REG, "comm": linux.DT_REG, "environ": linux.DT_REG, @@ -104,7 +105,7 @@ func setup(t *testing.T) *testutil.System { AllowUserMount: true, }) - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{}) + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.MountOptions{}) if err != nil { t.Fatalf("NewMountNamespace(): %v", err) } @@ -132,7 +133,7 @@ func setup(t *testing.T) *testutil.System { }, }, } - if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil { + if _, err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil { t.Fatalf("MountAt(/proc): %v", err) } return testutil.NewSystem(ctx, t, k.VFS(), mntns) @@ -218,7 +219,7 @@ func TestTasks(t *testing.T) { if err != nil { t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) buf := make([]byte, 1) bufIOSeq := usermem.BytesIOSequence(buf) if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { @@ -336,7 +337,7 @@ func TestTasksOffset(t *testing.T) { if err != nil { t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil { t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err) } @@ -441,7 +442,7 @@ func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.F t.Errorf("vfsfs.OpenAt(%v) failed: %v", absPath, err) continue } - defer child.DecRef() + defer child.DecRef(ctx) stat, err := child.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("Stat(%v) failed: %v", absPath, err) @@ -476,7 +477,7 @@ func TestTree(t *testing.T) { if err != nil { t.Fatalf("failed to create test file: %v", err) } - defer file.DecRef() + defer file.DecRef(s.Ctx) var tasks []*kernel.Task for i := 0; i < 5; i++ { @@ -501,5 +502,5 @@ func TestTree(t *testing.T) { t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err) } iterateDir(ctx, t, s, fd) - fd.DecRef() + fd.DecRef(ctx) } diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD index 067c1657f..adb610213 100644 --- a/pkg/sentry/fsimpl/signalfd/BUILD +++ b/pkg/sentry/fsimpl/signalfd/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/kernel", "//pkg/sentry/vfs", diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 242ba9b5d..10f1452ef 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -16,7 +16,6 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -26,7 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SignalFileDescription implements FileDescriptionImpl for signal fds. +// SignalFileDescription implements vfs.FileDescriptionImpl for signal fds. +// +// +stateify savable type SignalFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -43,7 +44,7 @@ type SignalFileDescription struct { target *kernel.Task // mu protects mask. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // mask is the signal mask. Protected by mu. mask linux.SignalSet @@ -54,7 +55,7 @@ var _ vfs.FileDescriptionImpl = (*SignalFileDescription)(nil) // New creates a new signal fd. func New(vfsObj *vfs.VirtualFilesystem, target *kernel.Task, mask linux.SignalSet, flags uint32) (*vfs.FileDescription, error) { vd := vfsObj.NewAnonVirtualDentry("[signalfd]") - defer vd.DecRef() + defer vd.DecRef(target) sfd := &SignalFileDescription{ target: target, mask: mask, @@ -83,7 +84,7 @@ func (sfd *SignalFileDescription) SetMask(mask linux.SignalSet) { sfd.mask = mask } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { // Attempt to dequeue relevant signals. info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0) @@ -93,8 +94,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen } // Copy out the signal info using the specified format. - var buf [128]byte - binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + infoNative := linux.SignalfdSiginfo{ Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, @@ -103,9 +103,13 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), - }) - n, err := dst.CopyOut(ctx, buf[:]) - return int64(n), err + } + n, err := infoNative.WriteTo(dst.Writer(ctx)) + if err == usermem.ErrEndOfIOSequence { + // Partial copy-out ok. + err = nil + } + return n, err } // Readiness implements waiter.Waitable.Readiness. @@ -132,5 +136,5 @@ func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) { sfd.target.SignalUnregister(entry) } -// Release implements FileDescriptionImpl.Release() -func (sfd *SignalFileDescription) Release() {} +// Release implements vfs.FileDescriptionImpl.Release. +func (sfd *SignalFileDescription) Release(context.Context) {} diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index ee0828a15..29e5371d6 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -28,14 +28,16 @@ import ( ) // filesystemType implements vfs.FilesystemType. +// +// +stateify savable type filesystemType struct{} -// GetFilesystem implements FilesystemType.GetFilesystem. +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { panic("sockfs.filesystemType.GetFilesystem should never be called") } -// Name implements FilesystemType.Name. +// Name implements vfs.FilesystemType.Name. // // Note that registering sockfs is unnecessary, except for the fact that it // will not show up under /proc/filesystems as a result. This is a very minor @@ -44,6 +46,7 @@ func (filesystemType) Name() string { return "sockfs" } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -67,9 +70,9 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // PrependPath implements vfs.FilesystemImpl.PrependPath. @@ -80,18 +83,25 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink kernfs.InodeAttrs kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink } // Open implements kernfs.Inode.Open. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return nil, syserror.ENXIO } +// StatFS implements kernfs.Inode.StatFS. +func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.SOCKFS_MAGIC), nil +} + // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index a741e2bb6..906cd52cb 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -1,21 +1,41 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "dir_refs", + out = "dir_refs.go", + package = "sys", + prefix = "dir", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "dir", + }, +) + go_library( name = "sys", srcs = [ + "dir_refs.go", + "kcov.go", "sys.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/coverage", + "//pkg/log", + "//pkg/refs", + "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -29,6 +49,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go new file mode 100644 index 000000000..1a6749e53 --- /dev/null +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -0,0 +1,120 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sys + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) *kernfs.Dentry { + k := &kcovInode{} + k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + d := &kernfs.Dentry{} + d.Init(k) + return d +} + +// kcovInode implements kernfs.Inode. +// +// +stateify savable +type kcovInode struct { + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + implStatFS +} + +func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + k := kernel.KernelFromContext(ctx) + if k == nil { + panic("KernelFromContext returned nil") + } + fd := &kcovFD{ + inode: i, + kcov: k.NewKcov(), + } + + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// +stateify savable +type kcovFD struct { + vfs.FileDescriptionDefaultImpl + vfs.NoLockFD + + vfsfd vfs.FileDescription + inode *kcovInode + kcov *kernel.Kcov +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *kcovFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + cmd := uint32(args[1].Int()) + arg := args[2].Uint64() + switch uint32(cmd) { + case linux.KCOV_INIT_TRACE: + return 0, fd.kcov.InitTrace(arg) + case linux.KCOV_ENABLE: + return 0, fd.kcov.EnableTrace(ctx, uint8(arg)) + case linux.KCOV_DISABLE: + if arg != 0 { + // This arg is unused; it should be 0. + return 0, syserror.EINVAL + } + return 0, fd.kcov.DisableTrace(ctx) + default: + return 0, syserror.ENOTTY + } +} + +// ConfigureMmap implements vfs.FileDescriptionImpl.ConfigureMmap. +func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.kcov.ConfigureMMap(ctx, opts) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *kcovFD) Release(ctx context.Context) { + // kcov instances have reference counts in Linux, but this seems sufficient + // for our purposes. + fd.kcov.Clear() +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *kcovFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + creds := auth.CredentialsFromContext(ctx) + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.SetStat(ctx, fs, creds, opts) +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *kcovFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return fd.inode.Stat(ctx, fd.vfsfd.Mount().Filesystem(), opts) +} diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 01ce30a4d..1568c581f 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/coverage" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -30,11 +31,16 @@ import ( // Name is the default filesystem name. const Name = "sysfs" +const defaultSysDirMode = linux.FileMode(0755) // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -57,9 +63,6 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt devMinor: devMinor, } fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - k := kernel.KernelFromContext(ctx) - maxCPUCores := k.ApplicationCores() - defaultSysDirMode := linux.FileMode(0755) root := fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ "block": fs.newDir(creds, defaultSysDirMode, nil), @@ -70,30 +73,58 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt "dev": fs.newDir(creds, defaultSysDirMode, nil), "devices": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ "system": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "cpu": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - }), + "cpu": cpuDir(ctx, fs, creds), }), }), "firmware": fs.newDir(creds, defaultSysDirMode, nil), "fs": fs.newDir(creds, defaultSysDirMode, nil), - "kernel": fs.newDir(creds, defaultSysDirMode, nil), + "kernel": kernelDir(ctx, fs, creds), "module": fs.newDir(creds, defaultSysDirMode, nil), "power": fs.newDir(creds, defaultSysDirMode, nil), }) return fs.VFSFilesystem(), root.VFSDentry(), nil } +func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry { + k := kernel.KernelFromContext(ctx) + maxCPUCores := k.ApplicationCores() + children := map[string]*kernfs.Dentry{ + "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + } + for i := uint(0); i < maxCPUCores; i++ { + children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil) + } + return fs.newDir(creds, defaultSysDirMode, children) +} + +func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry { + // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs + // should be mounted at debug/, but for our purposes, it is sufficient to + // keep it in sys. + var children map[string]*kernfs.Dentry + if coverage.KcovAvailable() { + children = map[string]*kernfs.Dentry{ + "debug": fs.newDir(creds, linux.FileMode(0700), map[string]*kernfs.Dentry{ + "kcov": fs.newKcovFile(ctx, creds), + }), + } + } + return fs.newDir(creds, defaultSysDirMode, children) +} + // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // dir implements kernfs.Inode. +// +// +stateify savable type dir struct { + dirRefs kernfs.InodeAttrs kernfs.InodeNoDynamicLookup kernfs.InodeNotSymlink @@ -109,6 +140,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte d := &dir{} d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + d.EnableLeakCheck() d.dentry.Init(d) d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents)) @@ -116,23 +148,39 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte return &d.dentry } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } // Open implements kernfs.Inode.Open. -func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) +func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +// DecRef implements kernfs.Inode.DecRef. +func (d *dir) DecRef(context.Context) { + d.dirRefs.DecRef(d.Destroy) +} + +// StatFS implements kernfs.Inode.StatFS. +func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.SYSFS_MAGIC), nil +} + // cpuFile implements kernfs.Inode. +// +// +stateify savable type cpuFile struct { + implStatFS kernfs.DynamicBytesFile + maxCores uint } @@ -149,3 +197,11 @@ func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode li d.Init(c) return d } + +// +stateify savable +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.SYSFS_MAGIC), nil +} diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 242d5fd12..0a0d914cc 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -38,7 +38,7 @@ func newTestSystem(t *testing.T) *testutil.System { AllowUserMount: true, }) - mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{}) + mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.MountOptions{}) if err != nil { t.Fatalf("Failed to create new mount namespace: %v", err) } @@ -59,7 +59,7 @@ func TestReadCPUFile(t *testing.T) { if err != nil { t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) content, err := s.ReadToEnd(fd) if err != nil { t.Fatalf("Read failed: %v", err) diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 0e4053a46..400a97996 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -32,6 +32,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index c16a36cdb..1813269e0 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -62,6 +62,7 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } + kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, } @@ -73,7 +74,7 @@ func Boot() (*kernel.Kernel, error) { k.SetMemoryFile(mf) // Pass k as the platform since it is savable, unlike the actual platform. - vdso, err := loader.PrepareVDSO(nil, k) + vdso, err := loader.PrepareVDSO(k) if err != nil { return nil, fmt.Errorf("creating vdso: %v", err) } @@ -103,11 +104,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("initializing kernel: %v", err) } - kernel.VFS2Enabled = true - - if err := k.VFS().Init(); err != nil { - return nil, fmt.Errorf("VFS init: %v", err) - } k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, AllowUserList: true, @@ -126,12 +122,16 @@ func Boot() (*kernel.Kernel, error) { // CreateTask creates a new bare bones task for tests. func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) { k := kernel.KernelFromContext(ctx) + if k == nil { + return nil, fmt.Errorf("cannot find kernel from context") + } + exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root) if err != nil { return nil, err } m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) - m.SetExecutable(fsbridge.NewVFSFile(exe)) + m.SetExecutable(ctx, fsbridge.NewVFSFile(exe)) config := &kernel.TaskConfig{ Kernel: k, diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go index 0556af877..568132121 100644 --- a/pkg/sentry/fsimpl/testutil/testutil.go +++ b/pkg/sentry/fsimpl/testutil/testutil.go @@ -97,8 +97,8 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System { // Destroy release resources associated with a test system. func (s *System) Destroy() { - s.Root.DecRef() - s.MntNs.DecRef() // Reference on MntNs passed to NewSystem. + s.Root.DecRef(s.Ctx) + s.MntNs.DecRef(s.Ctx) // Reference on MntNs passed to NewSystem. } // ReadToEnd reads the contents of fd until EOF to a string. @@ -149,7 +149,7 @@ func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector { if err != nil { s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) collector := &DirentCollector{} if err := fd.IterDirents(s.Ctx, collector); err != nil { diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 2dc90d484..8853c8ad2 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -26,8 +26,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// TimerFileDescription implements FileDescriptionImpl for timer fds. It also +// TimerFileDescription implements vfs.FileDescriptionImpl for timer fds. It also // implements ktime.TimerListener. +// +// +stateify savable type TimerFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -47,9 +49,9 @@ var _ vfs.FileDescriptionImpl = (*TimerFileDescription)(nil) var _ ktime.TimerListener = (*TimerFileDescription)(nil) // New returns a new timer fd. -func New(vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) { +func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) { vd := vfsObj.NewAnonVirtualDentry("[timerfd]") - defer vd.DecRef() + defer vd.DecRef(ctx) tfd := &TimerFileDescription{} tfd.timer = ktime.NewTimer(clock, tfd) if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{ @@ -62,7 +64,7 @@ func New(vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.F return &tfd.vfsfd, nil } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { const sizeofUint64 = 8 if dst.NumBytes() < sizeofUint64 { @@ -128,8 +130,8 @@ func (tfd *TimerFileDescription) ResumeTimer() { tfd.timer.Resume() } -// Release implements FileDescriptionImpl.Release() -func (tfd *TimerFileDescription) Release() { +// Release implements vfs.FileDescriptionImpl.Release. +func (tfd *TimerFileDescription) Release(context.Context) { tfd.timer.Destroy() } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index e73732a6b..5cd428d64 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -26,6 +26,17 @@ go_template_instance( }, ) +go_template_instance( + name = "inode_refs", + out = "inode_refs.go", + package = "tmpfs", + prefix = "inode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "inode", + }, +) + go_library( name = "tmpfs", srcs = [ @@ -34,6 +45,7 @@ go_library( "directory.go", "filesystem.go", "fstree.go", + "inode_refs.go", "named_pipe.go", "regular_file.go", "socket_file.go", @@ -47,6 +59,7 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go index 2fb5c4d84..5209a17af 100644 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go @@ -83,7 +83,7 @@ func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent } err = fn(root, d) - d.DecRef() + d.DecRef(ctx) return err } @@ -105,17 +105,17 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to create mount namespace: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create nested directories with given depth. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) d := root d.IncRef() - defer d.DecRef() + defer d.DecRef(ctx) for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { @@ -125,7 +125,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - d.DecRef() + d.DecRef(ctx) d = next filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -136,7 +136,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - file.DecRef() + file.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() @@ -176,24 +176,24 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { // Create VFS. vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { b.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create nested directories with given depth. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) vd := root vd.IncRef() for i := depth; i > 0; i-- { @@ -212,7 +212,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - vd.DecRef() + vd.DecRef(ctx) vd = nextVD filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -228,12 +228,12 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) - vd.DecRef() + vd.DecRef(ctx) vd = vfs.VirtualDentry{} if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - defer fd.DecRef() + defer fd.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() @@ -278,14 +278,14 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to create mount namespace: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create and mount the submount. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil { b.Fatalf("failed to create mount point: %v", err) } @@ -293,7 +293,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount point: %v", err) } - defer mountPoint.DecRef() + defer mountPoint.DecRef(ctx) submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) if err != nil { b.Fatalf("failed to create tmpfs submount: %v", err) @@ -309,7 +309,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount root: %v", err) } - defer d.DecRef() + defer d.DecRef(ctx) for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { @@ -319,7 +319,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - d.DecRef() + d.DecRef(ctx) d = next filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -330,7 +330,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - file.DecRef() + file.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() @@ -370,24 +370,24 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { // Create VFS. vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { b.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create the mount point. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) pop := vfs.PathOperation{ Root: root, Start: root, @@ -403,9 +403,9 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount point: %v", err) } - defer mountPoint.DecRef() + defer mountPoint.DecRef(ctx) // Create and mount the submount. - if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { + if _, err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) @@ -432,7 +432,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - vd.DecRef() + vd.DecRef(ctx) vd = nextVD filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -448,11 +448,11 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) - vd.DecRef() + vd.DecRef(ctx) if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - fd.DecRef() + fd.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go index ac54d420d..9129d35b7 100644 --- a/pkg/sentry/fsimpl/tmpfs/device_file.go +++ b/pkg/sentry/fsimpl/tmpfs/device_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" ) +// +stateify savable type deviceFile struct { inode inode kind vfs.DeviceKind diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 0a1ad4765..e90669cf0 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// +stateify savable type directory struct { // Since directories can't be hard-linked, each directory can only be // associated with a single dentry, which we can store in the directory @@ -44,7 +45,7 @@ type directory struct { // (with inode == nil) that represent the iteration position of // directoryFDs. childList is used to support directoryFD.IterDirents() // efficiently. childList is protected by iterMu. - iterMu sync.Mutex + iterMu sync.Mutex `state:"nosave"` childList dentryList } @@ -57,8 +58,9 @@ func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.Fi return dir } -// Preconditions: filesystem.mu must be locked for writing. dir must not -// already contain a child with the given name. +// Preconditions: +// * filesystem.mu must be locked for writing. +// * dir must not already contain a child with the given name. func (dir *directory) insertChildLocked(child *dentry, name string) { child.parent = &dir.dentry child.name = name @@ -85,6 +87,7 @@ func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error { return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid))) } +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl @@ -95,7 +98,7 @@ type directoryFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(ctx context.Context) { if fd.iter != nil { dir := fd.inode().impl.(*directory) dir.iterMu.Lock() @@ -110,7 +113,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fs := fd.filesystem() dir := fd.inode().impl.(*directory) - defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + defer fd.dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent) // fs.mu is required to read d.parent and dentry.name. fs.mu.RLock() diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index ed40f6b52..e39cd305b 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -39,8 +38,10 @@ func (fs *filesystem) Sync(ctx context.Context) error { // // stepLocked is loosely analogous to fs/namei.c:walk_component(). // -// Preconditions: filesystem.mu must be locked. !rp.Done(). -func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { +// Preconditions: +// * filesystem.mu must be locked. +// * !rp.Done(). +func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { dir, ok := d.inode.impl.(*directory) if !ok { return nil, syserror.ENOTDIR @@ -55,13 +56,13 @@ afterSymlink: return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() return d, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } rp.Advance() @@ -74,7 +75,7 @@ afterSymlink: if !ok { return nil, syserror.ENOENT } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { @@ -97,10 +98,12 @@ afterSymlink: // walkParentDirLocked is loosely analogous to Linux's // fs/namei.c:path_parentat(). // -// Preconditions: filesystem.mu must be locked. !rp.Done(). -func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) { +// Preconditions: +// * filesystem.mu must be locked. +// * !rp.Done(). +func walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*directory, error) { for !rp.Final() { - next, err := stepLocked(rp, d) + next, err := stepLocked(ctx, rp, d) if err != nil { return nil, err } @@ -118,10 +121,10 @@ func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) { // resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). // // Preconditions: filesystem.mu must be locked. -func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) { +func resolveLocked(ctx context.Context, rp *vfs.ResolvingPath) (*dentry, error) { d := rp.Start().Impl().(*dentry) for !rp.Done() { - next, err := stepLocked(rp, d) + next, err := stepLocked(ctx, rp, d) if err != nil { return nil, err } @@ -139,12 +142,13 @@ func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) { // doCreateAt is loosely analogous to a conjunction of Linux's // fs/namei.c:filename_create() and done_path_create(). // -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error { +// Preconditions: +// * !rp.Done(). +// * For the final path component in rp, !rp.ShouldFollowSymlink(). +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error { fs.mu.Lock() defer fs.mu.Unlock() - parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -182,7 +186,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if dir { ev |= linux.IN_ISDIR } - parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parentDir.inode.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) parentDir.inode.touchCMtime() return nil } @@ -191,7 +195,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return err } @@ -202,7 +206,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } @@ -222,7 +226,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { fs.mu.RLock() defer fs.mu.RUnlock() - dir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + dir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return nil, err } @@ -232,7 +236,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error { if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -251,7 +255,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EMLINK } i.incLinksLocked() - i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */) + i.watches.Notify(ctx, "", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */) parentDir.insertChildLocked(fs.newDentry(i), name) return nil }) @@ -259,7 +263,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parentDir *directory, name string) error { creds := rp.Credentials() if parentDir.inode.nlink == maxLinks { return syserror.EMLINK @@ -273,11 +277,11 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error { creds := rp.Credentials() var childInode *inode switch opts.Mode.FileType() { - case 0, linux.S_IFREG: + case linux.S_IFREG: childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) case linux.S_IFIFO: childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) @@ -307,30 +311,43 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // don't need fs.mu for writing. if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() - defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() return nil, err } + d.IncRef() + defer d.DecRef(ctx) + fs.mu.RUnlock() return d.open(ctx, rp, &opts, false /* afterCreate */) } mustCreate := opts.Flags&linux.O_EXCL != 0 start := rp.Start().Impl().(*dentry) fs.mu.Lock() - defer fs.mu.Unlock() + unlocked := false + unlock := func() { + if !unlocked { + fs.mu.Unlock() + unlocked = true + } + } + defer unlock() if rp.Done() { - // Reject attempts to open directories with O_CREAT. + // Reject attempts to open mount root directory with O_CREAT. if rp.MustBeDir() { return nil, syserror.EISDIR } if mustCreate { return nil, syserror.EEXIST } + start.IncRef() + defer start.DecRef(ctx) + unlock() return start.open(ctx, rp, &opts, false /* afterCreate */) } afterTrailingSymlink: - parentDir, err := walkParentDirLocked(rp, start) + parentDir, err := walkParentDirLocked(ctx, rp, start) if err != nil { return nil, err } @@ -364,11 +381,12 @@ afterTrailingSymlink: creds := rp.Credentials() child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)) parentDir.insertChildLocked(child, name) + unlock() fd, err := child.open(ctx, rp, &opts, true) if err != nil { return nil, err } - parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) + parentDir.inode.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) parentDir.inode.touchCMtime() return fd, nil } @@ -376,7 +394,7 @@ afterTrailingSymlink: return nil, syserror.EEXIST } // Is the file mounted over? - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } // Do we need to resolve a trailing symlink? @@ -389,13 +407,17 @@ afterTrailingSymlink: start = &parentDir.dentry goto afterTrailingSymlink } - // Open existing file. - if mustCreate { - return nil, syserror.EEXIST + if rp.MustBeDir() && !child.inode.isDir() { + return nil, syserror.ENOTDIR } + child.IncRef() + defer child.DecRef(ctx) + unlock() return child.open(ctx, rp, &opts, false) } +// Preconditions: The caller must hold no locks (since opening pipes may block +// indefinitely). func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if !afterCreate { @@ -445,7 +467,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return "", err } @@ -467,7 +489,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Resolve newParent first to verify that it's on this Mount. fs.mu.Lock() defer fs.mu.Unlock() - newParentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -555,7 +577,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) var replacedVFSD *vfs.Dentry if replaced != nil { replacedVFSD = &replaced.vfsd @@ -566,17 +588,19 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if replaced != nil { newParentDir.removeChildLocked(replaced) if replaced.inode.isDir() { - newParentDir.inode.decLinksLocked() // from replaced's ".." + // Remove links for replaced/. and replaced/.. + replaced.inode.decLinksLocked(ctx) + newParentDir.inode.decLinksLocked(ctx) } - replaced.inode.decLinksLocked() + replaced.inode.decLinksLocked(ctx) } oldParentDir.removeChildLocked(renamed) newParentDir.insertChildLocked(renamed, newName) - vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD) + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) oldParentDir.inode.touchCMtime() if oldParentDir != newParentDir { if renamed.inode.isDir() { - oldParentDir.inode.decLinksLocked() + oldParentDir.inode.decLinksLocked(ctx) newParentDir.inode.incLinksLocked() } newParentDir.inode.touchCMtime() @@ -591,7 +615,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -626,17 +650,17 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error defer mnt.EndWrite() vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } parentDir.removeChildLocked(child) - parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) + parentDir.inode.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) // Remove links for child, child/., and child/.. - child.inode.decLinksLocked() - child.inode.decLinksLocked() - parentDir.inode.decLinksLocked() - vfsObj.CommitDeleteDentry(&child.vfsd) + child.inode.decLinksLocked(ctx) + child.inode.decLinksLocked(ctx) + parentDir.inode.decLinksLocked(ctx) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) parentDir.inode.touchCMtime() return nil } @@ -644,19 +668,19 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil { - fs.mu.RUnlock() + err = d.inode.setStat(ctx, rp.Credentials(), &opts) + fs.mu.RUnlock() + if err != nil { return err } - fs.mu.RUnlock() if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - d.InotifyWithParent(ev, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } @@ -665,7 +689,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return linux.Statx{}, err } @@ -678,24 +702,15 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { fs.mu.RLock() defer fs.mu.RUnlock() - if _, err := resolveLocked(rp); err != nil { + if _, err := resolveLocked(ctx, rp); err != nil { return linux.Statfs{}, err } - statfs := linux.Statfs{ - Type: linux.TMPFS_MAGIC, - BlockSize: usermem.PageSize, - FragmentSize: usermem.PageSize, - NameLength: linux.NAME_MAX, - // TODO(b/29637826): Allow configuring a tmpfs size and enforce it. - Blocks: 0, - BlocksFree: 0, - } - return statfs, nil + return globalStatfs, nil } // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error { creds := rp.Credentials() child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target)) parentDir.insertChildLocked(child, name) @@ -707,7 +722,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -738,7 +753,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error defer mnt.EndWrite() vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } @@ -746,20 +761,20 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error // Generate inotify events. Note that this must take place before the link // count of the child is decremented, or else the watches may be dropped // before these events are added. - vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name) + vfs.InotifyRemoveChild(ctx, &child.inode.watches, &parentDir.inode.watches, name) parentDir.removeChildLocked(child) - child.inode.decLinksLocked() - vfsObj.CommitDeleteDentry(&child.vfsd) + child.inode.decLinksLocked(ctx) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) parentDir.inode.touchCMtime() return nil } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } @@ -768,67 +783,70 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } switch impl := d.inode.impl.(type) { case *socketFile: + if impl.ep == nil { + return nil, syserror.ECONNREFUSED + } return impl.ep, nil default: return nil, syserror.ECONNREFUSED } } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } - return d.inode.listxattr(size) + return d.inode.listXattr(size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return "", err } - return d.inode.getxattr(rp.Credentials(), &opts) + return d.inode.getXattr(rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { fs.mu.RLock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { - fs.mu.RUnlock() + err = d.inode.setXattr(rp.Credentials(), &opts) + fs.mu.RUnlock() + if err != nil { return err } - fs.mu.RUnlock() - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.removexattr(rp.Credentials(), name); err != nil { - fs.mu.RUnlock() + err = d.inode.removeXattr(rp.Credentials(), name) + fs.mu.RUnlock() + if err != nil { return err } - fs.mu.RUnlock() - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } @@ -847,8 +865,16 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } if d.parent == nil { if d.name != "" { - // This must be an anonymous memfd file. + // This file must have been created by + // newUnlinkedRegularFileDescription(). In Linux, + // mm/shmem.c:__shmem_file_setup() => + // fs/file_table.c:alloc_file_pseudo() sets the created + // dentry's dentry_operations to anon_ops, for which d_dname == + // simple_dname. fs/d_path.c:simple_dname() defines the + // dentry's pathname to be its name, prefixed with "/" and + // suffixed with " (deleted)". b.PrependComponent("/" + d.name) + b.AppendString(" (deleted)") return vfs.PrependPathSyntheticError{} } return vfs.PrependPathAtNonMountRootError{} diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 739350cf0..d772db9e9 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type namedPipe struct { inode inode @@ -28,8 +29,8 @@ type namedPipe struct { } // Preconditions: -// * fs.mu must be locked. -// * rp.Mount().CheckBeginWrite() has been called successfully. +// * fs.mu must be locked. +// * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 1614f2c39..be29a2363 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -32,7 +32,7 @@ const fileName = "mypipe" func TestSeparateFDs(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the read side. This is done in a concurrently because opening // One end the pipe blocks until the other end is opened. @@ -55,13 +55,13 @@ func TestSeparateFDs(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) } - defer wfd.DecRef() + defer wfd.DecRef(ctx) rfd, ok := <-rfdchan if !ok { t.Fatalf("failed to open pipe for reading %q", fileName) } - defer rfd.DecRef() + defer rfd.DecRef(ctx) const msg = "vamos azul" checkEmpty(ctx, t, rfd) @@ -71,7 +71,7 @@ func TestSeparateFDs(t *testing.T) { func TestNonblockingRead(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the read side as nonblocking. pop := vfs.PathOperation{ @@ -85,7 +85,7 @@ func TestNonblockingRead(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) } - defer rfd.DecRef() + defer rfd.DecRef(ctx) // Open the write side. openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} @@ -93,7 +93,7 @@ func TestNonblockingRead(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) } - defer wfd.DecRef() + defer wfd.DecRef(ctx) const msg = "geh blau" checkEmpty(ctx, t, rfd) @@ -103,7 +103,7 @@ func TestNonblockingRead(t *testing.T) { func TestNonblockingWriteError(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the write side as nonblocking, which should return ENXIO. pop := vfs.PathOperation{ @@ -121,7 +121,7 @@ func TestNonblockingWriteError(t *testing.T) { func TestSingleFD(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the pipe as readable and writable. pop := vfs.PathOperation{ @@ -135,7 +135,7 @@ func TestSingleFD(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) } - defer fd.DecRef() + defer fd.DecRef(ctx) const msg = "forza blu" checkEmpty(ctx, t, fd) @@ -152,13 +152,13 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy // Create VFS. vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 1cdb46e6f..a199eb33d 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -36,12 +36,18 @@ import ( ) // regularFile is a regular (=S_IFREG) tmpfs file. +// +// +stateify savable type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. memFile *pgalloc.MemoryFile + // memoryUsageKind is the memory accounting category under which pages backing + // this regularFile's contents are accounted. + memoryUsageKind usage.MemoryKind + // mapsMu protects mappings. mapsMu sync.Mutex `state:"nosave"` @@ -62,7 +68,7 @@ type regularFile struct { writableMappingPages uint64 // dataMu protects the fields below. - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` // data maps offsets into the file to offsets into memFile that store // the file's data. @@ -86,14 +92,75 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, - seals: linux.F_SEAL_SEAL, + memFile: fs.memFile, + memoryUsageKind: usage.Tmpfs, + seals: linux.F_SEAL_SEAL, } file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode) file.inode.nlink = 1 // from parent directory return &file.inode } +// newUnlinkedRegularFileDescription creates a regular file on the tmpfs +// filesystem represented by mount and returns an FD representing that file. +// The new file is not reachable by path traversal from any other file. +// +// newUnlinkedRegularFileDescription is analogous to Linux's +// mm/shmem.c:__shmem_file_setup(). +// +// Preconditions: mount must be a tmpfs mount. +func newUnlinkedRegularFileDescription(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, name string) (*regularFileFD, error) { + fs, ok := mount.Filesystem().Impl().(*filesystem) + if !ok { + panic("tmpfs.newUnlinkedRegularFileDescription() called with non-tmpfs mount") + } + + inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) + d := fs.newDentry(inode) + defer d.DecRef(ctx) + d.name = name + + fd := ®ularFileFD{} + fd.Init(&inode.locks) + flags := uint32(linux.O_RDWR) + if err := fd.vfsfd.Init(fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return fd, nil +} + +// NewZeroFile creates a new regular file and file description as for +// mmap(MAP_SHARED | MAP_ANONYMOUS). The file has the given size and is +// initially (implicitly) filled with zeroes. +// +// Preconditions: mount must be a tmpfs mount. +func NewZeroFile(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, size uint64) (*vfs.FileDescription, error) { + // Compare mm/shmem.c:shmem_zero_setup(). + fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, "dev/zero") + if err != nil { + return nil, err + } + rf := fd.inode().impl.(*regularFile) + rf.memoryUsageKind = usage.Anonymous + rf.size = size + return &fd.vfsfd, err +} + +// NewMemfd creates a new regular file and file description as for +// memfd_create. +// +// Preconditions: mount must be a tmpfs mount. +func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) { + fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, name) + if err != nil { + return nil, err + } + if allowSeals { + fd.inode().impl.(*regularFile).seals = 0 + } + return &fd.vfsfd, nil +} + // truncate grows or shrinks the file to the given size. It returns true if the // file size was updated. func (rf *regularFile) truncate(newSize uint64) (bool, error) { @@ -226,7 +293,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap. optional.End = pgend } - cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { + cerr := rf.data.Fill(ctx, required, optional, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { // Newly-allocated pages are zeroed, so we don't need to do anything. return dsts.NumBytes(), nil }) @@ -260,17 +327,18 @@ func (*regularFile) InvalidateUnsavable(context.Context) error { return nil } +// +stateify savable type regularFileFD struct { fileDescription // off is the file offset. off is accessed using atomic memory operations. // offMu serializes operations that may mutate off. off int64 - offMu sync.Mutex + offMu sync.Mutex `state:"nosave"` } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { +func (fd *regularFileFD) Release(context.Context) { // noop } @@ -325,8 +393,15 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since @@ -334,40 +409,44 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { - return 0, syserror.EOPNOTSUPP + return 0, offset, syserror.EOPNOTSUPP } srclen := src.NumBytes() if srclen == 0 { - return 0, nil + return 0, offset, nil } f := fd.inode().impl.(*regularFile) + f.inode.mu.Lock() + defer f.inode.mu.Unlock() + // If the file is opened with O_APPEND, update offset to file size. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking f.inode.mu is sufficient for reading f.size. + offset = int64(f.size) + } if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - var err error srclen, err = vfs.CheckLimit(ctx, offset, srclen) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(srclen) - f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) - fd.inode().touchCMtimeLocked() - f.inode.mu.Unlock() + f.inode.touchCMtimeLocked() putRegularFileReadWriter(rw) - return n, err + return n, n + offset, err } // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.offMu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.offMu.Unlock() return n, err } @@ -564,7 +643,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, case gap.Ok(): // Allocate memory for the write. gapMR := gap.Range().Intersect(pgMR) - fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs) + fr, err := rw.file.memFile.Allocate(gapMR.Length(), rw.file.memoryUsageKind) if err != nil { retErr = err goto exitLoop diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go index 3ed650474..5699d5975 100644 --- a/pkg/sentry/fsimpl/tmpfs/socket_file.go +++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go @@ -21,6 +21,8 @@ import ( ) // socketFile is a socket (=S_IFSOCK) tmpfs file. +// +// +stateify savable type socketFile struct { inode inode ep transport.BoundEndpoint diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go index b0de5fabe..a102a2ee2 100644 --- a/pkg/sentry/fsimpl/tmpfs/symlink.go +++ b/pkg/sentry/fsimpl/tmpfs/symlink.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) +// +stateify savable type symlink struct { inode inode target string // immutable diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index d7f4f0779..cefec8fde 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -51,9 +51,13 @@ import ( const Name = "tmpfs" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -67,7 +71,7 @@ type filesystem struct { devMinor uint32 // mu serializes changes to the Dentry tree. - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` nextInoMinusOne uint64 // accessed using atomic memory operations } @@ -78,6 +82,8 @@ func (FilesystemType) Name() string { } // FilesystemOpts is used to pass configuration data to tmpfs. +// +// +stateify savable type FilesystemOpts struct { // RootFileType is the FileType of the filesystem root. Valid values // are: S_IFDIR, S_IFREG, and S_IFLNK. Defaults to S_IFDIR. @@ -185,7 +191,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt case linux.S_IFDIR: root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry default: - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType) } return &fs.vfsfs, &root.vfsd, nil @@ -197,11 +203,32 @@ func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *au } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// immutable +var globalStatfs = linux.Statfs{ + Type: linux.TMPFS_MAGIC, + BlockSize: usermem.PageSize, + FragmentSize: usermem.PageSize, + NameLength: linux.NAME_MAX, + + // tmpfs currently does not support configurable size limits. In Linux, + // such a tmpfs mount will return f_blocks == f_bfree == f_bavail == 0 from + // statfs(2). However, many applications treat this as having a size limit + // of 0. To work around this, claim to have a very large but non-zero size, + // chosen to ensure that BlockSize * Blocks does not overflow int64 (which + // applications may also handle incorrectly). + // TODO(b/29637826): allow configuring a tmpfs size and enforce it. + Blocks: math.MaxInt64 / usermem.PageSize, + BlocksFree: math.MaxInt64 / usermem.PageSize, + BlocksAvailable: math.MaxInt64 / usermem.PageSize, +} + // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -249,12 +276,12 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { - d.inode.decRef() +func (d *dentry) DecRef(ctx context.Context) { + d.inode.decRef(ctx) } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { if d.inode.isDir() { events |= linux.IN_ISDIR } @@ -266,9 +293,9 @@ func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { d.inode.fs.mu.RLock() // The ordering below is important, Linux always notifies the parent first. if d.parent != nil { - d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted) + d.parent.inode.watches.Notify(ctx, d.name, events, cookie, et, deleted) } - d.inode.watches.Notify("", events, cookie, et, deleted) + d.inode.watches.Notify(ctx, "", events, cookie, et, deleted) d.inode.fs.mu.RUnlock() } @@ -278,20 +305,19 @@ func (d *dentry) Watches() *vfs.Watches { } // OnZeroWatches implements vfs.Dentry.OnZeroWatches. -func (d *dentry) OnZeroWatches() {} +func (d *dentry) OnZeroWatches(context.Context) {} // inode represents a filesystem object. +// +// +stateify savable type inode struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // refs is a reference count. refs is accessed using atomic memory - // operations. - // // A reference is held on all inodes as long as they are reachable in the // filesystem tree, i.e. nlink is nonzero. This reference is dropped when // nlink reaches 0. - refs int64 + refs inodeRefs // xattrs implements extended attributes. // @@ -300,12 +326,12 @@ type inode struct { // Inode metadata. Writing multiple fields atomically requires holding // mu, othewise atomic operations can be used. - mu sync.Mutex - mode uint32 // file type and mode - nlink uint32 // protected by filesystem.mu instead of inode.mu - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - ino uint64 // immutable + mu sync.Mutex `state:"nosave"` + mode uint32 // file type and mode + nlink uint32 // protected by filesystem.mu instead of inode.mu + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + ino uint64 // immutable // Linux's tmpfs has no concept of btime. atime int64 // nanoseconds @@ -327,7 +353,6 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth panic("file type is required in FileMode") } i.fs = fs - i.refs = 1 i.mode = uint32(mode) i.uid = uint32(kuid) i.gid = uint32(kgid) @@ -339,12 +364,15 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth i.mtime = now // i.nlink initialized by caller i.impl = impl + i.refs.EnableLeakCheck() } // incLinksLocked increments i's link count. // -// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. -// i.nlink < maxLinks. +// Preconditions: +// * filesystem.mu must be locked for writing. +// * i.nlink != 0. +// * i.nlink < maxLinks. func (i *inode) incLinksLocked() { if i.nlink == 0 { panic("tmpfs.inode.incLinksLocked() called with no existing links") @@ -358,46 +386,36 @@ func (i *inode) incLinksLocked() { // decLinksLocked decrements i's link count. If the link count reaches 0, we // remove a reference on i as well. // -// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. -func (i *inode) decLinksLocked() { +// Preconditions: +// * filesystem.mu must be locked for writing. +// * i.nlink != 0. +func (i *inode) decLinksLocked(ctx context.Context) { if i.nlink == 0 { panic("tmpfs.inode.decLinksLocked() called with no existing links") } if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 { - i.decRef() + i.decRef(ctx) } } func (i *inode) incRef() { - if atomic.AddInt64(&i.refs, 1) <= 1 { - panic("tmpfs.inode.incRef() called without holding a reference") - } + i.refs.IncRef() } func (i *inode) tryIncRef() bool { - for { - refs := atomic.LoadInt64(&i.refs) - if refs == 0 { - return false - } - if atomic.CompareAndSwapInt64(&i.refs, refs, refs+1) { - return true - } - } + return i.refs.TryIncRef() } -func (i *inode) decRef() { - if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { - i.watches.HandleDeletion() +func (i *inode) decRef(ctx context.Context) { + i.refs.DecRef(func() { + i.watches.HandleDeletion(ctx) if regFile, ok := i.impl.(*regularFile); ok { // Release memory used by regFile to store data. Since regFile is // no longer usable, we don't need to grab any locks or update any // metadata. regFile.data.DropAll(regFile.memFile) } - } else if refs < 0 { - panic("tmpfs.inode.decRef() called without holding a reference") - } + }) } func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { @@ -452,7 +470,8 @@ func (i *inode) statTo(stat *linux.Statx) { } } -func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error { +func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -460,7 +479,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&i.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { return err } i.mu.Lock() @@ -557,6 +576,8 @@ func (i *inode) direntType() uint8 { return linux.DT_LNK case *socketFile: return linux.DT_SOCK + case *namedPipe: + return linux.DT_FIFO case *deviceFile: switch impl.kind { case vfs.BlockDevice: @@ -606,66 +627,59 @@ func (i *inode) touchCMtime() { i.mu.Unlock() } -// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds -// inode.mu. +// Preconditions: +// * The caller has called vfs.Mount.CheckBeginWrite(). +// * inode.mu must be locked. func (i *inode) touchCMtimeLocked() { now := i.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&i.mtime, now) atomic.StoreInt64(&i.ctime, now) } -func (i *inode) listxattr(size uint64) ([]string, error) { - return i.xattrs.Listxattr(size) +func (i *inode) listXattr(size uint64) ([]string, error) { + return i.xattrs.ListXattr(size) } -func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { - if err := i.checkPermissions(creds, vfs.MayRead); err != nil { +func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { + if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return "", syserror.EOPNOTSUPP - } - if !i.userXattrSupported() { - return "", syserror.ENODATA - } - return i.xattrs.Getxattr(opts) + return i.xattrs.GetXattr(opts) } -func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error { - if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { +func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error { + if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !i.userXattrSupported() { - return syserror.EPERM - } - return i.xattrs.Setxattr(opts) + return i.xattrs.SetXattr(opts) } -func (i *inode) removexattr(creds *auth.Credentials, name string) error { - if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { +func (i *inode) removeXattr(creds *auth.Credentials, name string) error { + if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return i.xattrs.RemoveXattr(name) +} + +func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + // We currently only support extended attributes in the user.* and + // trusted.* namespaces. See b/148380782. + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { return syserror.EOPNOTSUPP } - if !i.userXattrSupported() { - return syserror.EPERM + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err } - return i.xattrs.Removexattr(name) -} - -// Extended attributes in the user.* namespace are only supported for regular -// files and directories. -func (i *inode) userXattrSupported() bool { - filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode) - return filetype == linux.S_IFREG || filetype == linux.S_IFDIR + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) } // fileDescription is embedded by tmpfs implementations of // vfs.FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -695,81 +709,55 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) d := fd.dentry() - if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, creds, &opts); err != nil { return err } if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - d.InotifyWithParent(ev, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { - return fd.inode().listxattr(size) +// StatFS implements vfs.FileDescriptionImpl.StatFS. +func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + return globalStatfs, nil } -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { - return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts) +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.inode().listXattr(size) } -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.inode().getXattr(auth.CredentialsFromContext(ctx), &opts) +} + +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { d := fd.dentry() - if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil { + if err := d.inode.setXattr(auth.CredentialsFromContext(ctx), &opts); err != nil { return err } // Generate inotify events. - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d := fd.dentry() - if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil { + if err := d.inode.removeXattr(auth.CredentialsFromContext(ctx), name); err != nil { return err } // Generate inotify events. - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// NewMemfd creates a new tmpfs regular file and file description that can back -// an anonymous fd created by memfd_create. -func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name string) (*vfs.FileDescription, error) { - fs, ok := mount.Filesystem().Impl().(*filesystem) - if !ok { - panic("NewMemfd() called with non-tmpfs mount") - } - - // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with - // S_IRWXUGO. - inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) - rf := inode.impl.(*regularFile) - if allowSeals { - rf.seals = 0 - } - - d := fs.newDentry(inode) - defer d.DecRef() - d.name = name - - // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with - // FMODE_READ | FMODE_WRITE. - var fd regularFileFD - fd.Init(&inode.locks) - flags := uint32(linux.O_RDWR) - if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go index a240fb276..99c8e3c0f 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go @@ -34,21 +34,21 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr creds := auth.CredentialsFromContext(ctx) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err) } root := mntns.Root() return vfsObj, root, func() { - root.DecRef() - mntns.DecRef() + root.DecRef(ctx) + mntns.DecRef(ctx) }, nil } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD new file mode 100644 index 000000000..0ca750281 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -0,0 +1,47 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +licenses(["notice"]) + +go_library( + name = "verity", + srcs = [ + "filesystem.go", + "verity.go", + ], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/marshal/primitive", + "//pkg/merkletree", + "//pkg/sentry/arch", + "//pkg/sentry/fs/lock", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + ], +) + +go_test( + name = "verity_test", + srcs = [ + "verity_test.go", + ], + library = ":verity", + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/contexttest", + "//pkg/sentry/vfs", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go new file mode 100644 index 000000000..a560b0797 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -0,0 +1,886 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *filesystem) Sync(ctx context.Context) error { + // All files should be read-only. + return nil +} + +var dentrySlicePool = sync.Pool{ + New: func() interface{} { + ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity + return &ds + }, +} + +func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry { + if ds == nil { + ds = dentrySlicePool.Get().(*[]*dentry) + } + *ds = append(*ds, d) + return ds +} + +// Preconditions: ds != nil. +func putDentrySlice(ds *[]*dentry) { + // Allow dentries to be GC'd. + for i := range *ds { + (*ds)[i] = nil + } + *ds = (*ds)[:0] + dentrySlicePool.Put(ds) +} + +// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls +// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for +// writing. +// +// ds is a pointer-to-pointer since defer evaluates its arguments immediately, +// but dentry slices are allocated lazily, and it's much easier to say "defer +// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { +// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. +func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { + fs.renameMu.RUnlock() + if *ds == nil { + return + } + if len(**ds) != 0 { + fs.renameMu.Lock() + for _, d := range **ds { + d.checkDropLocked(ctx) + } + fs.renameMu.Unlock() + } + putDentrySlice(*ds) +} + +func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { + if *ds == nil { + fs.renameMu.Unlock() + return + } + for _, d := range **ds { + d.checkDropLocked(ctx) + } + fs.renameMu.Unlock() + putDentrySlice(*ds) +} + +// stepLocked resolves rp.Component() to an existing file, starting from the +// given directory. +// +// Dentries which may have a reference count of zero, and which therefore +// should be dropped once traversal is complete, are appended to ds. +// +// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// !rp.Done(). +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { + if !d.isDir() { + return nil, syserror.ENOTDIR + } + + if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + +afterSymlink: + name := rp.Component() + if name == "." { + rp.Advance() + return d, nil + } + if name == ".." { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { + return nil, err + } else if isRoot || d.parent == nil { + rp.Advance() + return d, nil + } + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, err + } + rp.Advance() + return d.parent, nil + } + child, err := fs.getChildLocked(ctx, d, name, ds) + if err != nil { + return nil, err + } + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { + return nil, err + } + if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { + target, err := child.readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + goto afterSymlink // don't check the current directory again + } + rp.Advance() + return child, nil +} + +// verifyChild verifies the root hash of child against the already verified +// root hash of the parent to ensure the child is expected. verifyChild +// triggers a sentry panic if unexpected modifications to the file system are +// detected. In noCrashOnVerificationFailure mode it returns a syserror +// instead. +// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// TODO(b/166474175): Investigate all possible errors returned in this +// function, and make sure we differentiate all errors that indicate unexpected +// modifications to the file system from the ones that are not harmful. +func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { + vfsObj := fs.vfsfs.VirtualFilesystem() + + // Get the path to the child dentry. This is only used to provide path + // information in failure case. + childPath, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + verityMu.RLock() + defer verityMu.RUnlock() + // Read the offset of the child from the extended attributes of the + // corresponding Merkle tree file. + // This is the offset of the root hash for child in its parent's Merkle + // tree file. + off, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{ + Root: child.lowerMerkleVD, + Start: child.lowerMerkleVD, + }, &vfs.GetXattrOptions{ + Name: merkleOffsetInParentXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the file or the xattr does not + // exist, it indicates unexpected modifications to the file system. + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + } + if err != nil { + return nil, err + } + // The offset xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + offset, err := strconv.Atoi(off) + if err != nil { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + } + + // Open parent Merkle tree file to read and verify child's root hash. + parentMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerMerkleVD, + Start: parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + + // The parent Merkle tree file should have been created. If it's + // missing, it indicates an unexpected modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + } + if err != nil { + return nil, err + } + + // dataSize is the size of raw data for the Merkle tree. For a file, + // dataSize is the size of the whole file. For a directory, dataSize is + // the size of all its children's root hashes. + dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the file or the xattr does not + // exist, it indicates unexpected modifications to the file system. + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + } + if err != nil { + return nil, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + parentSize, err := strconv.Atoi(dataSize) + if err != nil { + return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + } + + fdReader := vfs.FileReadWriteSeeker{ + FD: parentMerkleFD, + Ctx: ctx, + } + + // Since we are verifying against a directory Merkle tree, buf should + // contain the root hash of the children in the parent Merkle tree when + // Verify returns with success. + var buf bytes.Buffer + if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { + return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + } + + // Cache child root hash when it's verified the first time. + if len(child.rootHash) == 0 { + child.rootHash = buf.Bytes() + } + return child, nil +} + +// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { + if child, ok := parent.children[name]; ok { + // If enabling verification on files/directories is not allowed + // during runtime, all cached children are already verified. If + // runtime enable is allowed and the parent directory is + // enabled, we should verify the child root hash here because + // it may be cached before enabled. + if fs.allowRuntimeEnable && len(parent.rootHash) != 0 { + if _, err := fs.verifyChild(ctx, parent, child); err != nil { + return nil, err + } + } + return child, nil + } + child, err := fs.lookupAndVerifyLocked(ctx, parent, name) + if err != nil { + return nil, err + } + if parent.children == nil { + parent.children = make(map[string]*dentry) + } + parent.children[name] = child + // child's refcount is initially 0, so it may be dropped after traversal. + *ds = appendDentry(*ds, child) + return child, nil +} + +// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { + vfsObj := fs.vfsfs.VirtualFilesystem() + + childFilename := fspath.Parse(name) + childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: childFilename, + }, &vfs.GetDentryOptions{}) + + // We will handle ENOENT separately, as it may indicate unexpected + // modifications to the file system, and may cause a sentry panic. + if childErr != nil && childErr != syserror.ENOENT { + return nil, childErr + } + + // The dentry needs to be cleaned up if any error occurs. IncRef will be + // called if a verity child dentry is successfully created. + if childErr == nil { + defer childVD.DecRef(ctx) + } + + childMerkleFilename := merklePrefix + name + childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(childMerkleFilename), + }, &vfs.GetDentryOptions{}) + + // We will handle ENOENT separately, as it may indicate unexpected + // modifications to the file system, and may cause a sentry panic. + if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { + return nil, childMerkleErr + } + + // The dentry needs to be cleaned up if any error occurs. IncRef will be + // called if a verity child dentry is successfully created. + if childMerkleErr == nil { + defer childMerkleVD.DecRef(ctx) + } + + // Get the path to the parent dentry. This is only used to provide path + // information in failure case. + parentPath, err := vfsObj.PathnameWithDeleted(ctx, parent.fs.rootDentry.lowerVD, parent.lowerVD) + if err != nil { + return nil, err + } + + // TODO(b/166474175): Investigate all possible errors of childErr and + // childMerkleErr, and make sure we differentiate all errors that + // indicate unexpected modifications to the file system from the ones + // that are not harmful. + if childErr == syserror.ENOENT && childMerkleErr == nil { + // Failed to get child file/directory dentry. However the + // corresponding Merkle tree is found. This indicates an + // unexpected modification to the file system that + // removed/renamed the child. + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + } else if childErr == nil && childMerkleErr == syserror.ENOENT { + // If in allowRuntimeEnable mode, and the Merkle tree file is + // not created yet, we create an empty Merkle tree file, so that + // if the file is enabled through ioctl, we have the Merkle tree + // file open and ready to use. + // This may cause empty and unused Merkle tree files in + // allowRuntimeEnable mode, if they are never enabled. This + // does not affect verification, as we rely on cached root hash + // to decide whether to perform verification, not the existence + // of the Merkle tree file. Also, those Merkle tree files are + // always hidden and cannot be accessed by verity fs users. + if fs.allowRuntimeEnable { + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(childMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(childMerkleFilename), + }, &vfs.GetDentryOptions{}) + if err != nil { + return nil, err + } + } else { + // If runtime enable is not allowed. This indicates an + // unexpected modification to the file system that + // removed/renamed the Merkle tree file. + return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + } + } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { + // Both the child and the corresponding Merkle tree are missing. + // This could be an unexpected modification or due to incorrect + // parameter. + // TODO(b/167752508): Investigate possible ways to differentiate + // cases that both files are deleted from cases that they never + // exist in the file system. + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + } + + mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) + stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ + Root: childVD, + Start: childVD, + }, &vfs.StatOptions{ + Mask: mask, + }) + if err != nil { + return nil, err + } + + child := fs.newDentry() + child.lowerVD = childVD + child.lowerMerkleVD = childMerkleVD + + // Increase the reference for both childVD and childMerkleVD as they are + // held by child. If this function fails and the child is destroyed, the + // references will be decreased in destroyLocked. + childVD.IncRef() + childMerkleVD.IncRef() + + parent.IncRef() + child.parent = parent + child.name = name + + // TODO(b/162788573): Verify child metadata. + child.mode = uint32(stat.Mode) + child.uid = stat.UID + child.gid = stat.GID + + // Verify child root hash. This should always be performed unless in + // allowRuntimeEnable mode and the parent directory hasn't been enabled + // yet. + if !(fs.allowRuntimeEnable && len(parent.rootHash) == 0) { + if _, err := fs.verifyChild(ctx, parent, child); err != nil { + child.destroyLocked(ctx) + return nil, err + } + } + + return child, nil +} + +// walkParentDirLocked resolves all but the last path component of rp to an +// existing directory, starting from the given directory (which is usually +// rp.Start().Impl().(*dentry)). It does not check that the returned directory +// is searchable by the provider of rp. +// +// Preconditions: fs.renameMu must be locked. !rp.Done(). +func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { + for !rp.Final() { + d.dirMu.Lock() + next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + d = next + } + if !d.isDir() { + return nil, syserror.ENOTDIR + } + return d, nil +} + +// resolveLocked resolves rp to an existing file. +// +// Preconditions: fs.renameMu must be locked. +func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { + d := rp.Start().Impl().(*dentry) + for !rp.Done() { + d.dirMu.Lock() + next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + d = next + } + if rp.MustBeDir() && !d.isDir() { + return nil, syserror.ENOTDIR + } + return d, nil +} + +// AccessAt implements vfs.Filesystem.Impl.AccessAt. +func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { + // Verity file system is read-only. + if ats&vfs.MayWrite != 0 { + return syserror.EROFS + } + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return err + } + return d.checkPermissions(creds, ats) +} + +// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. +func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return nil, err + } + if opts.CheckSearchable { + if !d.isDir() { + return nil, syserror.ENOTDIR + } + if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + } + d.IncRef() + return &d.vfsd, nil +} + +// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. +func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + start := rp.Start().Impl().(*dentry) + d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) + if err != nil { + return nil, err + } + d.IncRef() + return &d.vfsd, nil +} + +// LinkAt implements vfs.FilesystemImpl.LinkAt. +func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. +func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// OpenAt implements vfs.FilesystemImpl.OpenAt. +func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // Verity fs is read-only. + if opts.Flags&(linux.O_WRONLY|linux.O_CREAT) != 0 { + return nil, syserror.EROFS + } + + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + + start := rp.Start().Impl().(*dentry) + if rp.Done() { + return start.openLocked(ctx, rp, &opts) + } + +afterTrailingSymlink: + parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) + if err != nil { + return nil, err + } + + // Check for search permission in the parent directory. + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + + // Open existing child or follow symlink. + parent.dirMu.Lock() + child, err := fs.stepLocked(ctx, rp, parent, false /*mayFollowSymlinks*/, &ds) + parent.dirMu.Unlock() + if err != nil { + return nil, err + } + if child.isSymlink() && rp.ShouldFollowSymlink() { + target, err := child.readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + start = parent + goto afterTrailingSymlink + } + return child.openLocked(ctx, rp, &opts) +} + +// Preconditions: fs.renameMu must be locked. +func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { + // Users should not open the Merkle tree files. Those are for verity fs + // use only. + if strings.Contains(d.name, merklePrefix) { + return nil, syserror.EPERM + } + ats := vfs.AccessTypesForOpenFlags(opts) + if err := d.checkPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + + // Verity fs is read-only. + if ats&vfs.MayWrite != 0 { + return nil, syserror.EROFS + } + + // Get the path to the target file. This is only used to provide path + // information in failure case. + path, err := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD) + if err != nil { + return nil, err + } + + // Open the file in the underlying file system. + lowerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }, opts) + + // The file should exist, as we succeeded in finding its dentry. If it's + // missing, it indicates an unexpected modification to the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + } + return nil, err + } + + // lowerFD needs to be cleaned up if any error occurs. IncRef will be + // called if a verity FD is successfully created. + defer lowerFD.DecRef(ctx) + + // Open the Merkle tree file corresponding to the current file/directory + // to be used later for verifying Read/Walk. + merkleReader, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + + // The Merkle tree file should exist, as we succeeded in finding its + // dentry. If it's missing, it indicates an unexpected modification to + // the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + + // merkleReader needs to be cleaned up if any error occurs. IncRef will + // be called if a verity FD is successfully created. + defer merkleReader.DecRef(ctx) + + lowerFlags := lowerFD.StatusFlags() + lowerFDOpts := lowerFD.Options() + var merkleWriter *vfs.FileDescription + var parentMerkleWriter *vfs.FileDescription + + // Only open the Merkle tree files for write if in allowRuntimeEnable + // mode. + if d.fs.allowRuntimeEnable { + merkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + // merkleWriter is cleaned up if any error occurs. IncRef will + // be called if a verity FD is created successfully. + defer merkleWriter.DecRef(ctx) + + if d.parent != nil { + parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.parent.lowerMerkleVD, + Start: d.parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + } + return nil, err + } + // parentMerkleWriter is cleaned up if any error occurs. IncRef + // will be called if a verity FD is created successfully. + defer parentMerkleWriter.DecRef(ctx) + } + } + + fd := &fileDescription{ + d: d, + lowerFD: lowerFD, + merkleReader: merkleReader, + merkleWriter: merkleWriter, + parentMerkleWriter: parentMerkleWriter, + isDir: d.isDir(), + } + + if err := fd.vfsfd.Init(fd, lowerFlags, rp.Mount(), &d.vfsd, &lowerFDOpts); err != nil { + return nil, err + } + lowerFD.IncRef() + merkleReader.IncRef() + if merkleWriter != nil { + merkleWriter.IncRef() + } + if parentMerkleWriter != nil { + parentMerkleWriter.IncRef() + } + return &fd.vfsfd, err +} + +// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. +func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return "", err + } + //TODO(b/162787271): Provide integrity check for ReadlinkAt. + return fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }) +} + +// RenameAt implements vfs.FilesystemImpl.RenameAt. +func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// StatAt implements vfs.FilesystemImpl.StatAt. +func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return linux.Statx{}, err + } + + var stat linux.Statx + stat, err = fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }, &opts) + if err != nil { + return linux.Statx{}, err + } + return stat, nil +} + +// StatFSAt implements vfs.FilesystemImpl.StatFSAt. +func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { + // TODO(b/159261227): Implement StatFSAt. + return linux.Statfs{}, nil +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. +func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. +func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. +func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + if _, err := fs.resolveLocked(ctx, rp, &ds); err != nil { + return nil, err + } + return nil, syserror.ECONNREFUSED +} + +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return nil, err + } + lowerVD := d.lowerVD + return fs.vfsfs.VirtualFilesystem().ListXattrAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, size) +} + +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return "", err + } + lowerVD := d.lowerVD + return fs.vfsfs.VirtualFilesystem().GetXattrAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &opts) +} + +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + // Verity file system is read-only. + return syserror.EROFS +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.renameMu.RLock() + defer fs.renameMu.RUnlock() + mnt := vd.Mount() + d := vd.Dentry().Impl().(*dentry) + for { + if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { + return vfs.PrependPathAtVFSRootError{} + } + if &d.vfsd == mnt.Root() { + return nil + } + if d.parent == nil { + return vfs.PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go new file mode 100644 index 000000000..fc5eabbca --- /dev/null +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -0,0 +1,743 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package verity provides a filesystem implementation that is a wrapper of +// another file system. +// The verity file system provides integrity check for the underlying file +// system by providing verification for path traversals and each read. +// The verity file system is read-only, except for one case: when +// allowRuntimeEnable is true, additional Merkle files can be generated using +// the FS_IOC_ENABLE_VERITY ioctl. +package verity + +import ( + "fmt" + "strconv" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Name is the default filesystem name. +const Name = "verity" + +// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle +// tree file for "/foo" is "/.merkle.verity.foo". +const merklePrefix = ".merkle.verity." + +// merkleoffsetInParentXattr is the extended attribute name specifying the +// offset of child root hash in its parent's Merkle tree. +const merkleOffsetInParentXattr = "user.merkle.offset" + +// merkleSizeXattr is the extended attribute name specifying the size of data +// hashed by the corresponding Merkle tree. For a file, it's the size of the +// whole file. For a directory, it's the size of all its children's root hashes. +const merkleSizeXattr = "user.merkle.size" + +// sizeOfStringInt32 is the size for a 32 bit integer stored as string in +// extended attributes. The maximum value of a 32 bit integer is 10 digits. +const sizeOfStringInt32 = 10 + +// noCrashOnVerificationFailure indicates whether the sandbox should panic +// whenever verification fails. If true, an error is returned instead of +// panicking. This should only be set for tests. +// TOOD(b/165661693): Decide whether to panic or return error based on this +// flag. +var noCrashOnVerificationFailure bool + +// verityMu synchronizes enabling verity files, protects files or directories +// from being enabled by different threads simultaneously. It also ensures that +// verity does not access files that are being enabled. +var verityMu sync.RWMutex + +// FilesystemType implements vfs.FilesystemType. +// +// +stateify savable +type FilesystemType struct{} + +// filesystem implements vfs.FilesystemImpl. +// +// +stateify savable +type filesystem struct { + vfsfs vfs.Filesystem + + // creds is a copy of the filesystem's creator's credentials, which are + // used for accesses to the underlying file system. creds is immutable. + creds *auth.Credentials + + // allowRuntimeEnable is true if using ioctl with FS_IOC_ENABLE_VERITY + // to build Merkle trees in the verity file system is allowed. If this + // is false, no new Merkle trees can be built, and only the files that + // had Merkle trees before startup (e.g. from a host filesystem mounted + // with gofer fs) can be verified. + allowRuntimeEnable bool + + // lowerMount is the underlying file system mount. + lowerMount *vfs.Mount + + // rootDentry is the mount root Dentry for this file system, which + // stores the root hash of the whole file system in bytes. + rootDentry *dentry + + // renameMu synchronizes renaming with non-renaming operations in order + // to ensure consistent lock ordering between dentry.dirMu in different + // dentries. + renameMu sync.RWMutex `state:"nosave"` +} + +// InternalFilesystemOptions may be passed as +// vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem. +// +// +stateify savable +type InternalFilesystemOptions struct { + // RootMerkleFileName is the name of the verity root Merkle tree file. + RootMerkleFileName string + + // LowerName is the name of the filesystem wrapped by verity fs. + LowerName string + + // RootHash is the root hash of the overall verity file system. + RootHash []byte + + // AllowRuntimeEnable specifies whether the verity file system allows + // enabling verification for files (i.e. building Merkle trees) during + // runtime. + AllowRuntimeEnable bool + + // LowerGetFSOptions is the file system option for the lower layer file + // system wrapped by verity file system. + LowerGetFSOptions vfs.GetFilesystemOptions + + // NoCrashOnVerificationFailure indicates whether the sandbox should + // panic whenever verification fails. If true, an error is returned + // instead of panicking. This should only be set for tests. + NoCrashOnVerificationFailure bool +} + +// Name implements vfs.FilesystemType.Name. +func (FilesystemType) Name() string { + return Name +} + +// alertIntegrityViolation alerts a violation of integrity, which usually means +// unexpected modification to the file system is detected. In +// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. +func alertIntegrityViolation(err error, msg string) error { + if noCrashOnVerificationFailure { + return err + } + panic(msg) +} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + iopts, ok := opts.InternalData.(InternalFilesystemOptions) + if !ok { + ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") + return nil, nil, syserror.EINVAL + } + noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + + // Mount the lower file system. The lower file system is wrapped inside + // verity, and should not be exposed or connected. + mopts := &vfs.MountOptions{ + GetFilesystemOptions: iopts.LowerGetFSOptions, + InternalMount: true, + } + mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) + if err != nil { + return nil, nil, err + } + + fs := &filesystem{ + creds: creds.Fork(), + lowerMount: mnt, + allowRuntimeEnable: iopts.AllowRuntimeEnable, + } + fs.vfsfs.Init(vfsObj, &fstype, fs) + + // Construct the root dentry. + d := fs.newDentry() + d.refs = 1 + lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root()) + lowerVD.IncRef() + d.lowerVD = lowerVD + + rootMerkleName := merklePrefix + iopts.RootMerkleFileName + + lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + Path: fspath.Parse(rootMerkleName), + }, &vfs.GetDentryOptions{}) + + // If runtime enable is allowed, the root merkle tree may be absent. We + // should create the tree file. + if err == syserror.ENOENT && fs.allowRuntimeEnable { + lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + Path: fspath.Parse(rootMerkleName), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, err + } + lowerMerkleFD.DecRef(ctx) + lowerMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + Path: fspath.Parse(rootMerkleName), + }, &vfs.GetDentryOptions{}) + if err != nil { + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, err + } + } else if err != nil { + // Failed to get dentry for the root Merkle file. This + // indicates an unexpected modification that removed/renamed + // the root Merkle file, or it's never generated. + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + } + d.lowerMerkleVD = lowerMerkleVD + + // Get metadata from the underlying file system. + const statMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.StatOptions{ + Mask: statMask, + }) + if err != nil { + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, err + } + + // TODO(b/162788573): Verify Metadata. + d.mode = uint32(stat.Mode) + d.uid = stat.UID + d.gid = stat.GID + + d.rootHash = make([]byte, len(iopts.RootHash)) + copy(d.rootHash, iopts.RootHash) + d.vfsd.Init(d) + + fs.rootDentry = d + + return &fs.vfsfs, &d.vfsd, nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release(ctx context.Context) { + fs.lowerMount.DecRef(ctx) +} + +// dentry implements vfs.DentryImpl. +// +// +stateify savable +type dentry struct { + vfsd vfs.Dentry + + refs int64 + + // fs is the owning filesystem. fs is immutable. + fs *filesystem + + // mode, uid and gid are the file mode, owner, and group of the file in + // the underlying file system. + mode uint32 + uid uint32 + gid uint32 + + // parent is the dentry corresponding to this dentry's parent directory. + // name is this dentry's name in parent. If this dentry is a filesystem + // root, parent is nil and name is the empty string. parent and name are + // protected by fs.renameMu. + parent *dentry + name string + + // If this dentry represents a directory, children maps the names of + // children for which dentries have been instantiated to those dentries, + // and dirents (if not nil) is a cache of dirents as returned by + // directoryFDs representing this directory. children is protected by + // dirMu. + dirMu sync.Mutex `state:"nosave"` + children map[string]*dentry + + // lowerVD is the VirtualDentry in the underlying file system. + lowerVD vfs.VirtualDentry + + // lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree + // in the underlying file system. + lowerMerkleVD vfs.VirtualDentry + + // rootHash is the rootHash for the current file or directory. + rootHash []byte +} + +// newDentry creates a new dentry representing the given verity file. The +// dentry initially has no references; it is the caller's responsibility to set +// the dentry's reference count and/or call dentry.destroy() as appropriate. +// The dentry is initially invalid in that it contains no underlying dentry; +// the caller is responsible for setting them. +func (fs *filesystem) newDentry() *dentry { + d := &dentry{ + fs: fs, + } + d.vfsd.Init(d) + return d +} + +// IncRef implements vfs.DentryImpl.IncRef. +func (d *dentry) IncRef() { + atomic.AddInt64(&d.refs, 1) +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *dentry) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&d.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + return true + } + } +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *dentry) DecRef(ctx context.Context) { + if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + d.fs.renameMu.Lock() + d.checkDropLocked(ctx) + d.fs.renameMu.Unlock() + } else if refs < 0 { + panic("verity.dentry.DecRef() called without holding a reference") + } +} + +// checkDropLocked should be called after d's reference count becomes 0 or it +// becomes deleted. +func (d *dentry) checkDropLocked(ctx context.Context) { + // Dentries with a positive reference count must be retained. Dentries + // with a negative reference count have already been destroyed. + if atomic.LoadInt64(&d.refs) != 0 { + return + } + // Refs is still zero; destroy it. + d.destroyLocked(ctx) + return +} + +// destroyLocked destroys the dentry. +// +// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. +func (d *dentry) destroyLocked(ctx context.Context) { + switch atomic.LoadInt64(&d.refs) { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("verity.dentry.destroyLocked() called on already destroyed dentry") + default: + panic("verity.dentry.destroyLocked() called with references on the dentry") + } + + if d.lowerVD.Ok() { + d.lowerVD.DecRef(ctx) + } + + if d.lowerMerkleVD.Ok() { + d.lowerMerkleVD.DecRef(ctx) + } + + if d.parent != nil { + d.parent.dirMu.Lock() + if !d.vfsd.IsDead() { + delete(d.parent.children, d.name) + } + d.parent.dirMu.Unlock() + if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { + d.parent.checkDropLocked(ctx) + } else if refs < 0 { + panic("verity.dentry.DecRef() called without holding a reference") + } + } +} + +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { + //TODO(b/159261227): Implement InotifyWithParent. +} + +// Watches implements vfs.DentryImpl.Watches. +func (d *dentry) Watches() *vfs.Watches { + //TODO(b/159261227): Implement Watches. + return nil +} + +// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. +func (d *dentry) OnZeroWatches(context.Context) { + //TODO(b/159261227): Implement OnZeroWatches. +} + +func (d *dentry) isSymlink() bool { + return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK +} + +func (d *dentry) isDir() bool { + return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR +} + +func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) +} + +func (d *dentry) readlink(ctx context.Context) (string, error) { + return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }) +} + +// FileDescription implements vfs.FileDescriptionImpl for verity fds. +// FileDescription is a wrapper of the underlying lowerFD, with support to build +// Merkle trees through the Linux fs-verity API to verify contents read from +// lowerFD. +// +// +stateify savable +type fileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + // d is the corresponding dentry to the fileDescription. + d *dentry + + // isDir specifies whehter the fileDescription points to a directory. + isDir bool + + // lowerFD is the FileDescription corresponding to the file in the + // underlying file system. + lowerFD *vfs.FileDescription + + // merkleReader is the read-only FileDescription corresponding to the + // Merkle tree file in the underlying file system. + merkleReader *vfs.FileDescription + + // merkleWriter is the FileDescription corresponding to the Merkle tree + // file in the underlying file system for writing. This should only be + // used when allowRuntimeEnable is set to true. + merkleWriter *vfs.FileDescription + + // parentMerkleWriter is the FileDescription of the Merkle tree for the + // directory that contains the current file/directory. This is only used + // if allowRuntimeEnable is set to true. + parentMerkleWriter *vfs.FileDescription +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *fileDescription) Release(ctx context.Context) { + fd.lowerFD.DecRef(ctx) + fd.merkleReader.DecRef(ctx) + if fd.merkleWriter != nil { + fd.merkleWriter.DecRef(ctx) + } + if fd.parentMerkleWriter != nil { + fd.parentMerkleWriter.DecRef(ctx) + } +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + // TODO(b/162788573): Add integrity check for metadata. + stat, err := fd.lowerFD.Stat(ctx, opts) + if err != nil { + return linux.Statx{}, err + } + return stat, nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + // Verity files are read-only. + return syserror.EPERM +} + +// generateMerkle generates a Merkle tree file for fd. If fd points to a file +// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The root +// hash of the generated Merkle tree and the data size is returned. +// If fd points to a regular file, the data is the content of the file. If fd +// points to a directory, the data is all root hahes of its children, written +// to the Merkle tree file. +func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) { + fdReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + merkleWriter := vfs.FileReadWriteSeeker{ + FD: fd.merkleWriter, + Ctx: ctx, + } + var rootHash []byte + var dataSize uint64 + + switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { + case linux.S_IFREG: + // For a regular file, generate a Merkle tree based on its + // content. + var err error + stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return nil, 0, err + } + dataSize = stat.Size + + rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */) + if err != nil { + return nil, 0, err + } + case linux.S_IFDIR: + // For a directory, generate a Merkle tree based on the root + // hashes of its children that has already been written to the + // Merkle tree file. + merkleStat, err := fd.merkleReader.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return nil, 0, err + } + dataSize = merkleStat.Size + + rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */) + if err != nil { + return nil, 0, err + } + default: + // TODO(b/167728857): Investigate whether and how we should + // enable other types of file. + return nil, 0, syserror.EINVAL + } + return rootHash, dataSize, nil +} + +// enableVerity enables verity features on fd by generating a Merkle tree file +// and stores its root hash in its parent directory's Merkle tree. +func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) { + if !fd.d.fs.allowRuntimeEnable { + return 0, syserror.EPERM + } + + // Lock to prevent other threads performing enable or access the file + // while it's being enabled. + verityMu.Lock() + defer verityMu.Unlock() + + // In allowRuntimeEnable mode, the underlying fd and read/write fd for + // the Merkle tree file should have all been initialized. For any file + // or directory other than the root, the parent Merkle tree file should + // have also been initialized. + if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { + return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + } + + rootHash, dataSize, err := fd.generateMerkle(ctx) + if err != nil { + return 0, err + } + + if fd.parentMerkleWriter != nil { + stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return 0, err + } + + // Write the root hash of fd to the parent directory's Merkle + // tree file, as it should be part of the parent Merkle tree + // data. parentMerkleWriter is open with O_APPEND, so it + // should write directly to the end of the file. + if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { + return 0, err + } + + // Record the offset of the root hash of fd in parent directory's + // Merkle tree file. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleOffsetInParentXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return 0, err + } + } + + // Record the size of the data being hashed for fd. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleSizeXattr, + Value: strconv.Itoa(int(dataSize)), + }); err != nil { + return 0, err + } + fd.d.rootHash = append(fd.d.rootHash, rootHash...) + return 0, nil +} + +// measureVerity returns the root hash of fd, saved in args[2]. +func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + var metadata linux.DigestMetadata + + // If allowRuntimeEnable is true, an empty fd.d.rootHash indicates that + // verity is not enabled for the file. If allowRuntimeEnable is false, + // this is an integrity violation because all files should have verity + // enabled, in which case fd.d.rootHash should be set. + if len(fd.d.rootHash) == 0 { + if fd.d.fs.allowRuntimeEnable { + return 0, syserror.ENODATA + } + return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no root hash found") + } + + // The first part of VerityDigest is the metadata. + if _, err := metadata.CopyIn(t, verityDigest); err != nil { + return 0, err + } + if metadata.DigestSize < uint16(len(fd.d.rootHash)) { + return 0, syserror.EOVERFLOW + } + + // Populate the output digest size, since DigestSize is both input and + // output. + metadata.DigestSize = uint16(len(fd.d.rootHash)) + + // First copy the metadata. + if _, err := metadata.CopyOut(t, verityDigest); err != nil { + return 0, err + } + + // Now copy the root hash bytes to the memory after metadata. + _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.rootHash) + return 0, err +} + +func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) { + f := int32(0) + + // All enabled files should store a root hash. This flag is not settable + // via FS_IOC_SETFLAGS. + if len(fd.d.rootHash) != 0 { + f |= linux.FS_VERITY_FL + } + + t := kernel.TaskFromContext(ctx) + _, err := primitive.CopyInt32Out(t, flags, f) + return 0, err +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + switch cmd := args[1].Uint(); cmd { + case linux.FS_IOC_ENABLE_VERITY: + return fd.enableVerity(ctx, uio) + case linux.FS_IOC_MEASURE_VERITY: + return fd.measureVerity(ctx, uio, args[2].Pointer()) + case linux.FS_IOC_GETFLAGS: + return fd.verityFlags(ctx, uio, args[2].Pointer()) + default: + // TODO(b/169682228): Investigate which ioctl commands should + // be allowed. + return 0, syserror.ENOSYS + } +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // No need to verify if the file is not enabled yet in + // allowRuntimeEnable mode. + if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 { + return fd.lowerFD.PRead(ctx, dst, offset, opts) + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return 0, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + dataReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */) + if err != nil { + return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err)) + } + return n, err +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go new file mode 100644 index 000000000..8bcc14131 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -0,0 +1,429 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "fmt" + "io" + "math/rand" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// rootMerkleFilename is the name of the root Merkle tree file. +const rootMerkleFilename = "root.verity" + +// maxDataSize is the maximum data size written to the file for test. +const maxDataSize = 100000 + +// newVerityRoot creates a new verity mount, and returns the root. The +// underlying file system is tmpfs. If the error is not nil, then cleanup +// should be called when the root is no longer needed. +func newVerityRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) { + rand.Seed(time.Now().UnixNano()) + vfsObj := &vfs.VirtualFilesystem{} + if err := vfsObj.Init(ctx); err != nil { + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) + } + + vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: InternalFilesystemOptions{ + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + AllowRuntimeEnable: true, + NoCrashOnVerificationFailure: true, + }, + }, + }) + if err != nil { + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) + } + root := mntns.Root() + return vfsObj, root, func() { + root.DecRef(ctx) + mntns.DecRef(ctx) + }, nil +} + +// newFileFD creates a new file in the verity mount, and returns the FD. The FD +// points to a file that has random data generated. +func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { + creds := auth.CredentialsFromContext(ctx) + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + + // Create the file in the underlying file system. + lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, + Mode: linux.ModeRegular | mode, + }) + if err != nil { + return nil, 0, err + } + + // Generate random data to be written to the file. + dataSize := rand.Intn(maxDataSize) + 1 + data := make([]byte, dataSize) + rand.Read(data) + + // Write directly to the underlying FD, since verity FD is read-only. + n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + return nil, 0, err + } + + if n != int64(len(data)) { + return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data)) + } + + lowerFD.DecRef(ctx) + + // Now open the verity file descriptor. + fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular | mode, + }) + return fd, dataSize, err +} + +// corruptRandomBit randomly flips a bit in the file represented by fd. +func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { + // Flip a random bit in the underlying file. + randomPos := int64(rand.Intn(size)) + byteToModify := make([]byte, 1) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { + return fmt.Errorf("lowerFD.PRead: %v", err) + } + byteToModify[0] ^= 1 + if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil { + return fmt.Errorf("lowerFD.PWrite: %v", err) + } + return nil +} + +// TestOpen ensures that when a file is created, the corresponding Merkle tree +// file and the root Merkle tree file exist. +func TestOpen(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } +} + +// TestUntouchedFileSucceeds ensures that read from an untouched verity file +// succeeds after enabling verity for it. +func TestReadUntouchedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } +} + +// TestReopenUntouchedFileSucceeds ensures that reopen an untouched verity file +// succeeds after enabling verity for it. +func TestReopenUntouchedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } +} + +// TestModifiedFileFails ensures that read from a modified verity file fails. +func TestModifiedFileFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded with modified file") + } +} + +// TestModifiedMerkleFails ensures that read from a verity file fails if the +// corresponding Merkle tree file is modified. +func TestModifiedMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } +} + +// TestModifiedParentMerkleFails ensures that open a verity enabled file in a +// verity enabled directory fails if the hashes related to the target file in +// the parent Merkle tree file is modified. +func TestModifiedParentMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. + parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD + + parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: parentLowerMerkleVD, + Start: parentLowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + parentMerkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } +} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 07bf39fed..5bba9de0b 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -15,6 +15,7 @@ go_library( ], deps = [ "//pkg/context", + "//pkg/tcpip", "//pkg/tcpip/stack", ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 2916a0644..fbe6d6aa6 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,7 +15,10 @@ // Package inet defines semantics for IP stacks. package inet -import "gvisor.dev/gvisor/pkg/tcpip/stack" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) // Stack represents a TCP/IP stack. type Stack interface { @@ -56,6 +59,12 @@ type Stack interface { // settings. SetTCPSACKEnabled(enabled bool) error + // TCPRecovery returns the TCP loss detection algorithm. + TCPRecovery() (TCPLossRecovery, error) + + // SetTCPRecovery attempts to change TCP loss detection algorithm. + SetTCPRecovery(recovery TCPLossRecovery) error + // Statistics reports stack statistics. Statistics(stat interface{}, arg string) error @@ -74,6 +83,12 @@ type Stack interface { // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. RestoreCleanupEndpoints([]stack.TransportEndpoint) + + // Forwarding returns if packet forwarding between NICs is enabled. + Forwarding(protocol tcpip.NetworkProtocolNumber) bool + + // SetForwarding enables or disables packet forwarding between NICs. + SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error } // Interface contains information about a network interface. @@ -189,3 +204,14 @@ type StatSNMPUDP [8]uint64 // StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. type StatSNMPUDPLite [8]uint64 + +// TCPLossRecovery indicates TCP loss detection and recovery methods to use. +type TCPLossRecovery int32 + +// Loss recovery constants from include/net/tcp.h which are used to set +// /proc/sys/net/ipv4/tcp_recovery. +const ( + TCP_RACK_LOSS_DETECTION TCPLossRecovery = 1 << iota + TCP_RACK_STATIC_REO_WND + TCP_RACK_NO_DUPTHRESH +) diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index d8961fc94..1779cc6f3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,7 +14,10 @@ package inet -import "gvisor.dev/gvisor/pkg/tcpip/stack" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) // TestStack is a dummy implementation of Stack for tests. type TestStack struct { @@ -25,6 +28,8 @@ type TestStack struct { TCPRecvBufSize TCPBufferSize TCPSendBufSize TCPBufferSize TCPSACKFlag bool + Recovery TCPLossRecovery + IPForwarding bool } // NewTestStack returns a TestStack with no network interfaces. The value of @@ -91,6 +96,17 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { return nil } +// TCPRecovery implements Stack.TCPRecovery. +func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) { + return s.Recovery, nil +} + +// SetTCPRecovery implements Stack.SetTCPRecovery. +func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error { + s.Recovery = recovery + return nil +} + // Statistics implements inet.Stack.Statistics. func (s *TestStack) Statistics(stat interface{}, arg string) error { return nil @@ -116,3 +132,14 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + return s.IPForwarding +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + s.IPForwarding = enable + return nil +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 25fe1921b..5de70aecb 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -69,8 +69,52 @@ go_template_instance( prefix = "socket", template = "//pkg/ilist:generic_list", types = { - "Element": "*SocketEntry", - "Linker": "*SocketEntry", + "Element": "*SocketRecordVFS1", + "Linker": "*SocketRecordVFS1", + }, +) + +go_template_instance( + name = "fd_table_refs", + out = "fd_table_refs.go", + package = "kernel", + prefix = "FDTable", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "FDTable", + }, +) + +go_template_instance( + name = "fs_context_refs", + out = "fs_context_refs.go", + package = "kernel", + prefix = "FSContext", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "FSContext", + }, +) + +go_template_instance( + name = "process_group_refs", + out = "process_group_refs.go", + package = "kernel", + prefix = "ProcessGroup", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "ProcessGroup", + }, +) + +go_template_instance( + name = "session_refs", + out = "session_refs.go", + package = "kernel", + prefix = "Session", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Session", }, ) @@ -88,9 +132,13 @@ go_library( "aio.go", "context.go", "fd_table.go", + "fd_table_refs.go", "fd_table_unsafe.go", "fs_context.go", + "fs_context_refs.go", "ipc_namespace.go", + "kcov.go", + "kcov_unsafe.go", "kernel.go", "kernel_opts.go", "kernel_state.go", @@ -99,6 +147,7 @@ go_library( "pending_signals_state.go", "posixtimer.go", "process_group_list.go", + "process_group_refs.go", "ptrace.go", "ptrace_amd64.go", "ptrace_arm64.go", @@ -106,6 +155,7 @@ go_library( "seccomp.go", "seqatomic_taskgoroutineschedinfo_unsafe.go", "session_list.go", + "session_refs.go", "sessions.go", "signal.go", "signal_handlers.go", @@ -132,6 +182,7 @@ go_library( "task_stop.go", "task_syscall.go", "task_usermem.go", + "task_work.go", "thread_group.go", "threads.go", "timekeeper.go", @@ -146,22 +197,26 @@ go_library( "gvisor.dev/gvisor/pkg/sentry/device", "gvisor.dev/gvisor/pkg/tcpip", ], + marshal = True, visibility = ["//:sandbox"], deps = [ ":uncaught_signal_go_proto", "//pkg/abi", "//pkg/abi/linux", "//pkg/amutex", - "//pkg/binary", "//pkg/bits", "//pkg/bpf", "//pkg/context", + "//pkg/coverage", "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", + "//pkg/refs_vfs2", "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", @@ -208,7 +263,6 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 920fe4329..1b9721534 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -15,28 +15,21 @@ package kernel import ( + "fmt" "syscall" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refs_vfs2" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" ) // +stateify savable type abstractEndpoint struct { - ep transport.BoundEndpoint - wr *refs.WeakRef - name string - ns *AbstractSocketNamespace -} - -// WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (e *abstractEndpoint) WeakRefGone() { - e.ns.mu.Lock() - if e.ns.endpoints[e.name].ep == e.ep { - delete(e.ns.endpoints, e.name) - } - e.ns.mu.Unlock() + ep transport.BoundEndpoint + socket refs_vfs2.RefCounter + name string + ns *AbstractSocketNamespace } // AbstractSocketNamespace is used to implement the Linux abstract socket functionality. @@ -45,7 +38,11 @@ func (e *abstractEndpoint) WeakRefGone() { type AbstractSocketNamespace struct { mu sync.Mutex `state:"nosave"` - // Keeps mapping from name to endpoint. + // Keeps a mapping from name to endpoint. AbstractSocketNamespace does not hold + // any references on any sockets that it contains; when retrieving a socket, + // TryIncRef() must be called in case the socket is concurrently being + // destroyed. It is the responsibility of the socket to remove itself from the + // abstract socket namespace when it is destroyed. endpoints map[string]abstractEndpoint } @@ -57,16 +54,16 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { } // A boundEndpoint wraps a transport.BoundEndpoint to maintain a reference on -// its backing object. +// its backing socket. type boundEndpoint struct { transport.BoundEndpoint - rc refs.RefCounter + socket refs_vfs2.RefCounter } // Release implements transport.BoundEndpoint.Release. -func (e *boundEndpoint) Release() { - e.rc.DecRef() - e.BoundEndpoint.Release() +func (e *boundEndpoint) Release(ctx context.Context) { + e.socket.DecRef(ctx) + e.BoundEndpoint.Release(ctx) } // BoundEndpoint retrieves the endpoint bound to the given name. The return @@ -80,32 +77,59 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp return nil } - rc := ep.wr.Get() - if rc == nil { - delete(a.endpoints, name) + if !ep.socket.TryIncRef() { + // The socket has reached zero references and is being destroyed. return nil } - return &boundEndpoint{ep.ep, rc} + return &boundEndpoint{ep.ep, ep.socket} } // Bind binds the given socket. // -// When the last reference managed by rc is dropped, ep may be removed from the +// When the last reference managed by socket is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() + // Check if there is already a socket (which has not yet been destroyed) bound at name. if ep, ok := a.endpoints[name]; ok { - if rc := ep.wr.Get(); rc != nil { - rc.DecRef() + if ep.socket.TryIncRef() { + ep.socket.DecRef(ctx) return syscall.EADDRINUSE } } ae := abstractEndpoint{ep: ep, name: name, ns: a} - ae.wr = refs.NewWeakRef(rc, &ae) + ae.socket = socket a.endpoints[name] = ae return nil } + +// Remove removes the specified socket at name from the abstract socket +// namespace, if it has not yet been replaced. +func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) { + a.mu.Lock() + defer a.mu.Unlock() + + ep, ok := a.endpoints[name] + if !ok { + // We never delete a map entry apart from a socket's destructor (although the + // map entry may be overwritten). Therefore, a socket should exist, even if it + // may not be the one we expect. + panic(fmt.Sprintf("expected socket to exist at '%s' in abstract socket namespace", name)) + } + + // A Bind() operation may race with callers of Remove(), e.g. in the + // following case: + // socket1 reaches zero references and begins destruction + // a.Bind("foo", ep, socket2) replaces socket1 with socket2 + // socket1's destructor calls a.Remove("foo", socket1) + // + // Therefore, we need to check that the socket at name is what we expect + // before modifying the map. + if ep.socket == socket { + delete(a.endpoints, name) + } +} diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 2bc49483a..869e49ebc 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -57,6 +57,7 @@ go_library( "id_map_set.go", "user_namespace.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index ef5723127..c08d47787 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -34,3 +34,23 @@ func CredentialsFromContext(ctx context.Context) *Credentials { } return NewAnonymousCredentials() } + +// ContextWithCredentials returns a copy of ctx carrying creds. +func ContextWithCredentials(ctx context.Context, creds *Credentials) context.Context { + return &authContext{ctx, creds} +} + +type authContext struct { + context.Context + creds *Credentials +} + +// Value implements context.Context. +func (ac *authContext) Value(key interface{}) interface{} { + switch key { + case CtxCredentials: + return ac.creds + default: + return ac.Context.Value(key) + } +} diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go index 0a58ba17c..4c32ee703 100644 --- a/pkg/sentry/kernel/auth/id.go +++ b/pkg/sentry/kernel/auth/id.go @@ -19,9 +19,13 @@ import ( ) // UID is a user ID in an unspecified user namespace. +// +// +marshal type UID uint32 // GID is a group ID in an unspecified user namespace. +// +// +marshal slice:GIDSlice type GID uint32 // In the root user namespace, user/group IDs have a 1-to-1 relationship with diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 4c0f1e41f..15519f0df 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -76,8 +76,8 @@ type pollEntry struct { // WeakRefGone implements refs.WeakRefUser.WeakRefGone. // weakReferenceGone is called when the file in the weak reference is destroyed. // The poll entry is removed in response to this. -func (p *pollEntry) WeakRefGone() { - p.epoll.RemoveEntry(p.id) +func (p *pollEntry) WeakRefGone(ctx context.Context) { + p.epoll.RemoveEntry(ctx, p.id) } // EventPoll holds all the state associated with an event poll object, that is, @@ -144,14 +144,14 @@ func NewEventPoll(ctx context.Context) *fs.File { // name matches fs/eventpoll.c:epoll_create1. dirent := fs.NewDirent(ctx, anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]")) // Release the initial dirent reference after NewFile takes a reference. - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{ files: make(map[FileIdentifier]*pollEntry), }) } // Release implements fs.FileOperations.Release. -func (e *EventPoll) Release() { +func (e *EventPoll) Release(ctx context.Context) { // We need to take the lock now because files may be attempting to // remove entries in parallel if they get destroyed. e.mu.Lock() @@ -160,7 +160,7 @@ func (e *EventPoll) Release() { // Go through all entries and clean up. for _, entry := range e.files { entry.id.File.EventUnregister(&entry.waiter) - entry.file.Drop() + entry.file.Drop(ctx) } e.files = nil } @@ -423,7 +423,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter } // RemoveEntry a files from the collection of observed files. -func (e *EventPoll) RemoveEntry(id FileIdentifier) error { +func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error { e.mu.Lock() defer e.mu.Unlock() @@ -445,7 +445,7 @@ func (e *EventPoll) RemoveEntry(id FileIdentifier) error { // Remove file from map, and drop weak reference. delete(e.files, id) - entry.file.Drop() + entry.file.Drop(ctx) return nil } diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go index 22630e9c5..55b505593 100644 --- a/pkg/sentry/kernel/epoll/epoll_test.go +++ b/pkg/sentry/kernel/epoll/epoll_test.go @@ -26,7 +26,8 @@ func TestFileDestroyed(t *testing.T) { f := filetest.NewTestFile(t) id := FileIdentifier{f, 12} - efile := NewEventPoll(contexttest.Context(t)) + ctx := contexttest.Context(t) + efile := NewEventPoll(ctx) e := efile.FileOperations.(*EventPoll) if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil { t.Fatalf("addEntry failed: %v", err) @@ -44,7 +45,7 @@ func TestFileDestroyed(t *testing.T) { } // Destroy the file. Check that we get no more events. - f.DecRef() + f.DecRef(ctx) evt = e.ReadEvents(1) if len(evt) != 0 { diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 87951adeb..bbf568dfc 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -70,7 +70,7 @@ func New(ctx context.Context, initVal uint64, semMode bool) *fs.File { // name matches fs/eventfd.c:eventfd_file_create. dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[eventfd]") // Release the initial dirent reference after NewFile takes a reference. - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{ val: initVal, semMode: semMode, @@ -106,7 +106,7 @@ func (e *EventOperations) HostFD() (int, error) { } // Release implements fs.FileOperations.Release. -func (e *EventOperations) Release() { +func (e *EventOperations) Release(context.Context) { e.mu.Lock() defer e.mu.Unlock() if e.hostfd >= 0 { diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 4b7d234a4..0ec7344cd 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -78,7 +77,8 @@ type descriptor struct { // // +stateify savable type FDTable struct { - refs.AtomicRefCount + FDTableRefs + k *Kernel // mu protects below. @@ -98,7 +98,7 @@ type FDTable struct { func (f *FDTable) saveDescriptorTable() map[int32]descriptor { m := make(map[int32]descriptor) - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { m[fd] = descriptor{ file: file, fileVFS2: fileVFS2, @@ -109,24 +109,28 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { } func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { + ctx := context.Background() f.init() // Initialize table. + f.used = 0 for fd, d := range m { - f.setAll(fd, d.file, d.fileVFS2, d.flags) + if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { + panic("VFS1 or VFS2 files set") + } // Note that we do _not_ need to acquire a extra table reference here. The // table reference will already be accounted for in the file, so we drop the // reference taken by set above. switch { case d.file != nil: - d.file.DecRef() + d.file.DecRef(ctx) case d.fileVFS2 != nil: - d.fileVFS2.DecRef() + d.fileVFS2.DecRef(ctx) } } } // drop drops the table reference. -func (f *FDTable) drop(file *fs.File) { +func (f *FDTable) drop(ctx context.Context, file *fs.File) { // Release locks. file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF}) @@ -144,14 +148,14 @@ func (f *FDTable) drop(file *fs.File) { d.InotifyEvent(ev, 0) // Drop the table reference. - file.DecRef() + file.DecRef(ctx) } // dropVFS2 drops the table reference. -func (f *FDTable) dropVFS2(file *vfs.FileDescription) { +func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the // entire file. - err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET) + err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET) if err != nil && err != syserror.ENOLCK { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) } @@ -161,10 +165,10 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) { if file.IsWritable() { ev = linux.IN_CLOSE_WRITE } - file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent) // Drop the table's reference. - file.DecRef() + file.DecRef(ctx) } // NewFDTable allocates a new FDTable that may be used by tasks in k. @@ -174,28 +178,21 @@ func (k *Kernel) NewFDTable() *FDTable { return f } -// destroy removes all of the file descriptors from the map. -func (f *FDTable) destroy() { - f.RemoveIf(func(*fs.File, *vfs.FileDescription, FDFlags) bool { - return true +// DecRef implements RefCounter.DecRef. +// +// If f reaches zero references, all of its file descriptors are removed. +func (f *FDTable) DecRef(ctx context.Context) { + f.FDTableRefs.DecRef(func() { + f.RemoveIf(ctx, func(*fs.File, *vfs.FileDescription, FDFlags) bool { + return true + }) }) } -// DecRef implements RefCounter.DecRef with destructor f.destroy. -func (f *FDTable) DecRef() { - f.DecRefWithDestructor(f.destroy) -} - -// Size returns the number of file descriptor slots currently allocated. -func (f *FDTable) Size() int { - size := atomic.LoadInt32(&f.used) - return int(size) -} - // forEach iterates over all non-nil files in sorted order. // // It is the caller's responsibility to acquire an appropriate lock. -func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { +func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { // retries tracks the number of failed TryIncRef attempts for the same FD. retries := 0 fd := int32(0) @@ -214,7 +211,7 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes continue // Race caught. } fn(fd, file, nil, flags) - file.DecRef() + file.DecRef(ctx) case fileVFS2 != nil: if !fileVFS2.TryIncRef() { retries++ @@ -224,7 +221,7 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes continue // Race caught. } fn(fd, nil, fileVFS2, flags) - fileVFS2.DecRef() + fileVFS2.DecRef(ctx) } retries = 0 fd++ @@ -234,7 +231,8 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes // String is a stringer for FDTable. func (f *FDTable) String() string { var buf strings.Builder - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + ctx := context.Background() + f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { switch { case file != nil: n, _ := file.Dirent.FullName(nil /* root */) @@ -242,7 +240,7 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() - name, err := vfsObj.PathnameWithDeleted(context.Background(), vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) + name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "<err: %v>\n", err) return @@ -277,7 +275,6 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags } f.mu.Lock() - defer f.mu.Unlock() // From f.next to find available fd. if fd < f.next { @@ -287,15 +284,25 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { - f.set(i, files[len(fds)], flags) // Set the descriptor. - fds = append(fds, i) // Record the file descriptor. + // Set the descriptor. + f.set(ctx, i, files[len(fds)], flags) + fds = append(fds, i) // Record the file descriptor. } } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { - f.set(i, nil, FDFlags{}) // Zap entry. + f.set(ctx, i, nil, FDFlags{}) + } + f.mu.Unlock() + + // Drop the reference taken by the call to f.set() that + // originally installed the file. Don't call f.drop() + // (generating inotify events, etc.) since the file should + // appear to have never been inserted into f. + for _, file := range files[:len(fds)] { + file.DecRef(ctx) } return nil, syscall.EMFILE } @@ -305,6 +312,7 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.next = fds[len(fds)-1] + 1 } + f.mu.Unlock() return fds, nil } @@ -332,7 +340,6 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes } f.mu.Lock() - defer f.mu.Unlock() // From f.next to find available fd. if fd < f.next { @@ -342,15 +349,25 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.getVFS2(i); d == nil { - f.setVFS2(i, files[len(fds)], flags) // Set the descriptor. - fds = append(fds, i) // Record the file descriptor. + // Set the descriptor. + f.setVFS2(ctx, i, files[len(fds)], flags) + fds = append(fds, i) // Record the file descriptor. } } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { - f.setVFS2(i, nil, FDFlags{}) // Zap entry. + f.setVFS2(ctx, i, nil, FDFlags{}) + } + f.mu.Unlock() + + // Drop the reference taken by the call to f.setVFS2() that + // originally installed the file. Don't call f.dropVFS2() + // (generating inotify events, etc.) since the file should + // appear to have never been inserted into f. + for _, file := range files[:len(fds)] { + file.DecRef(ctx) } return nil, syscall.EMFILE } @@ -360,6 +377,7 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes f.next = fds[len(fds)-1] + 1 } + f.mu.Unlock() return fds, nil } @@ -395,7 +413,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc } for fd < end { if d, _, _ := f.getVFS2(fd); d == nil { - f.setVFS2(fd, file, flags) + f.setVFS2(ctx, fd, file, flags) if fd == f.next { // Update next search start position. f.next = fd + 1 @@ -411,40 +429,55 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc // reference for that FD, the ref count for that existing reference is // decremented. func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error { - return f.newFDAt(ctx, fd, file, nil, flags) + df, _, err := f.newFDAt(ctx, fd, file, nil, flags) + if err != nil { + return err + } + if df != nil { + f.drop(ctx, df) + } + return nil } // NewFDAtVFS2 sets the file reference for the given FD. If there is an active // reference for that FD, the ref count for that existing reference is // decremented. func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error { - return f.newFDAt(ctx, fd, nil, file, flags) + _, dfVFS2, err := f.newFDAt(ctx, fd, nil, file, flags) + if err != nil { + return err + } + if dfVFS2 != nil { + f.dropVFS2(ctx, dfVFS2) + } + return nil } -func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error { +func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription, error) { if fd < 0 { // Don't accept negative FDs. - return syscall.EBADF + return nil, nil, syscall.EBADF } // Check the limit for the provided file. if limitSet := limits.FromContext(ctx); limitSet != nil { if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur { - return syscall.EMFILE + return nil, nil, syscall.EMFILE } } // Install the entry. f.mu.Lock() defer f.mu.Unlock() - f.setAll(fd, file, fileVFS2, flags) - return nil + + df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags) + return df, dfVFS2, nil } // SetFlags sets the flags for the given file descriptor. // // True is returned iff flags were changed. -func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { +func (f *FDTable) SetFlags(ctx context.Context, fd int32, flags FDFlags) error { if fd < 0 { // Don't accept negative FDs. return syscall.EBADF @@ -460,14 +493,14 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { } // Update the flags. - f.set(fd, file, flags) + f.set(ctx, fd, file, flags) return nil } // SetFlagsVFS2 sets the flags for the given file descriptor. // // True is returned iff flags were changed. -func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { +func (f *FDTable) SetFlagsVFS2(ctx context.Context, fd int32, flags FDFlags) error { if fd < 0 { // Don't accept negative FDs. return syscall.EBADF @@ -483,7 +516,7 @@ func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { } // Update the flags. - f.setVFS2(fd, file, flags) + f.setVFS2(ctx, fd, file, flags) return nil } @@ -541,50 +574,23 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) { // // Precondition: The caller must be running on the task goroutine, or Task.mu // must be locked. -func (f *FDTable) GetFDs() []int32 { +func (f *FDTable) GetFDs(ctx context.Context) []int32 { fds := make([]int32, 0, int(atomic.LoadInt32(&f.used))) - f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { + f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { fds = append(fds, fd) }) return fds } -// GetRefs returns a stable slice of references to all files and bumps the -// reference count on each. The caller must use DecRef on each reference when -// they're done using the slice. -func (f *FDTable) GetRefs() []*fs.File { - files := make([]*fs.File, 0, f.Size()) - f.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - file.IncRef() // Acquire a reference for caller. - files = append(files, file) - }) - return files -} - -// GetRefsVFS2 returns a stable slice of references to all files and bumps the -// reference count on each. The caller must use DecRef on each reference when -// they're done using the slice. -func (f *FDTable) GetRefsVFS2() []*vfs.FileDescription { - files := make([]*vfs.FileDescription, 0, f.Size()) - f.forEach(func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) { - file.IncRef() // Acquire a reference for caller. - files = append(files, file) - }) - return files -} - // Fork returns an independent FDTable. -func (f *FDTable) Fork() *FDTable { +func (f *FDTable) Fork(ctx context.Context) *FDTable { clone := f.k.NewFDTable() - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { // The set function here will acquire an appropriate table // reference for the clone. We don't need anything else. - switch { - case file != nil: - clone.set(fd, file, flags) - case fileVFS2 != nil: - clone.setVFS2(fd, fileVFS2, flags) + if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil { + panic("VFS1 or VFS2 files set") } }) return clone @@ -593,13 +599,12 @@ func (f *FDTable) Fork() *FDTable { // Remove removes an FD from and returns a non-file iff successful. // // N.B. Callers are required to use DecRef when they are done. -func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { +func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDescription) { if fd < 0 { return nil, nil } f.mu.Lock() - defer f.mu.Unlock() // Update current available position. if fd < f.next { @@ -615,24 +620,51 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { case orig2 != nil: orig2.IncRef() } + if orig != nil || orig2 != nil { - f.setAll(fd, nil, nil, FDFlags{}) // Zap entry. + orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry. + } + f.mu.Unlock() + + if orig != nil { + f.drop(ctx, orig) } + if orig2 != nil { + f.dropVFS2(ctx, orig2) + } + return orig, orig2 } // RemoveIf removes all FDs where cond is true. -func (f *FDTable) RemoveIf(cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) { - f.mu.Lock() - defer f.mu.Unlock() +func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) { + // TODO(gvisor.dev/issue/1624): Remove fs.File slice. + var files []*fs.File + var filesVFS2 []*vfs.FileDescription - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + f.mu.Lock() + f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { if cond(file, fileVFS2, flags) { - f.set(fd, nil, FDFlags{}) // Clear from table. + df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table. + if df != nil { + files = append(files, df) + } + if dfVFS2 != nil { + filesVFS2 = append(filesVFS2, dfVFS2) + } // Update current available position. if fd < f.next { f.next = fd } } }) + f.mu.Unlock() + + for _, file := range files { + f.drop(ctx, file) + } + + for _, file := range filesVFS2 { + f.dropVFS2(ctx, file) + } } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 29f95a2c4..bf5460083 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -72,7 +72,7 @@ func TestFDTableMany(t *testing.T) { } i := int32(2) - fdTable.Remove(i) + fdTable.Remove(ctx, i) if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) } @@ -93,7 +93,7 @@ func TestFDTableOverLimit(t *testing.T) { t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) } else { for _, fd := range fds { - fdTable.Remove(fd) + fdTable.Remove(ctx, fd) } } @@ -150,13 +150,13 @@ func TestFDTable(t *testing.T) { t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref) } - ref, _ := fdTable.Remove(1) + ref, _ := fdTable.Remove(ctx, 1) if ref == nil { t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success") } - ref.DecRef() + ref.DecRef(ctx) - if ref, _ := fdTable.Remove(1); ref != nil { + if ref, _ := fdTable.Remove(ctx, 1); ref != nil { t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") } }) @@ -191,7 +191,7 @@ func BenchmarkFDLookupAndDecRef(b *testing.B) { b.StartTimer() // Benchmark. for i := 0; i < b.N; i++ { tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef() + tf.DecRef(ctx) } }) } @@ -219,7 +219,7 @@ func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) { defer wg.Done() for i := 0; i < each; i++ { tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef() + tf.DecRef(ctx) } }() } diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 7fd97dc53..da79e6627 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -31,6 +32,8 @@ type descriptorTable struct { } // init initializes the table. +// +// TODO(gvisor.dev/1486): Enable leak check for FDTable. func (f *FDTable) init() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) @@ -76,33 +79,37 @@ func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, boo return d.file, d.fileVFS2, d.flags, true } -// set sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// CurrentMaxFDs returns the number of file descriptors that may be stored in f +// without reallocation. +func (f *FDTable) CurrentMaxFDs() int { + slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + return len(slice) +} + +// set sets an entry for VFS1, refer to setAll(). // // Precondition: mu must be held. -func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) { - f.setAll(fd, file, nil, flags) +func (f *FDTable) set(ctx context.Context, fd int32, file *fs.File, flags FDFlags) *fs.File { + dropFile, _ := f.setAll(ctx, fd, file, nil, flags) + return dropFile } -// setVFS2 sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// setVFS2 sets an entry for VFS2, refer to setAll(). // // Precondition: mu must be held. -func (f *FDTable) setVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) { - f.setAll(fd, nil, file, flags) +func (f *FDTable) setVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) *vfs.FileDescription { + _, dropFile := f.setAll(ctx, fd, nil, file, flags) + return dropFile } -// setAll sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// setAll sets the file description referred to by fd to file/fileVFS2. If +// file/fileVFS2 are non-nil, it takes a reference on them. If setAll replaces +// an existing file description, it returns it with the FDTable's reference +// transferred to the caller, which must call f.drop/dropVFS2() on the returned +// file after unlocking f.mu. // // Precondition: mu must be held. -func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { +func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription) { if file != nil && fileVFS2 != nil { panic("VFS1 and VFS2 files set") } @@ -145,25 +152,25 @@ func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, } } - // Drop the table reference. + // Adjust used. + switch { + case orig == nil && desc != nil: + atomic.AddInt32(&f.used, 1) + case orig != nil && desc == nil: + atomic.AddInt32(&f.used, -1) + } + if orig != nil { switch { case orig.file != nil: if desc == nil || desc.file != orig.file { - f.drop(orig.file) + return orig.file, nil } case orig.fileVFS2 != nil: if desc == nil || desc.fileVFS2 != orig.fileVFS2 { - f.dropVFS2(orig.fileVFS2) + return nil, orig.fileVFS2 } } } - - // Adjust used. - switch { - case orig == nil && desc != nil: - atomic.AddInt32(&f.used, 1) - case orig != nil && desc == nil: - atomic.AddInt32(&f.used, -1) - } + return nil, nil } diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index 47f78df9a..d46d1e1c1 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -17,7 +17,7 @@ package kernel import ( "fmt" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -29,7 +29,7 @@ import ( // // +stateify savable type FSContext struct { - refs.AtomicRefCount + FSContextRefs // mu protects below. mu sync.Mutex `state:"nosave"` @@ -63,7 +63,7 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext { cwd: cwd, umask: umask, } - f.EnableLeakCheck("kernel.FSContext") + f.EnableLeakCheck() return &f } @@ -76,54 +76,56 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext { cwdVFS2: cwd, umask: umask, } - f.EnableLeakCheck("kernel.FSContext") + f.EnableLeakCheck() return &f } -// destroy is the destructor for an FSContext. +// DecRef implements RefCounter.DecRef. // -// This will call DecRef on both root and cwd Dirents. If either call to -// DecRef returns an error, then it will be propagated. If both calls to -// DecRef return an error, then the one from root.DecRef will be propagated. +// When f reaches zero references, DecRef will be called on both root and cwd +// Dirents. // // Note that there may still be calls to WorkingDirectory() or RootDirectory() // (that return nil). This is because valid references may still be held via // proc files or other mechanisms. -func (f *FSContext) destroy() { - // Hold f.mu so that we don't race with RootDirectory() and - // WorkingDirectory(). - f.mu.Lock() - defer f.mu.Unlock() - - if VFS2Enabled { - f.rootVFS2.DecRef() - f.rootVFS2 = vfs.VirtualDentry{} - f.cwdVFS2.DecRef() - f.cwdVFS2 = vfs.VirtualDentry{} - } else { - f.root.DecRef() - f.root = nil - f.cwd.DecRef() - f.cwd = nil - } -} - -// DecRef implements RefCounter.DecRef with destructor f.destroy. -func (f *FSContext) DecRef() { - f.DecRefWithDestructor(f.destroy) +func (f *FSContext) DecRef(ctx context.Context) { + f.FSContextRefs.DecRef(func() { + // Hold f.mu so that we don't race with RootDirectory() and + // WorkingDirectory(). + f.mu.Lock() + defer f.mu.Unlock() + + if VFS2Enabled { + f.rootVFS2.DecRef(ctx) + f.rootVFS2 = vfs.VirtualDentry{} + f.cwdVFS2.DecRef(ctx) + f.cwdVFS2 = vfs.VirtualDentry{} + } else { + f.root.DecRef(ctx) + f.root = nil + f.cwd.DecRef(ctx) + f.cwd = nil + } + }) } // Fork forks this FSContext. // -// This is not a valid call after destroy. +// This is not a valid call after f is destroyed. func (f *FSContext) Fork() *FSContext { f.mu.Lock() defer f.mu.Unlock() if VFS2Enabled { + if !f.cwdVFS2.Ok() { + panic("FSContext.Fork() called after destroy") + } f.cwdVFS2.IncRef() f.rootVFS2.IncRef() } else { + if f.cwd == nil { + panic("FSContext.Fork() called after destroy") + } f.cwd.IncRef() f.root.IncRef() } @@ -139,8 +141,8 @@ func (f *FSContext) Fork() *FSContext { // WorkingDirectory returns the current working directory. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() @@ -151,8 +153,8 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() @@ -164,8 +166,8 @@ func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { // SetWorkingDirectory sets the current working directory. // This will take an extra reference on the Dirent. // -// This is not a valid call after destroy. -func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) { +// This is not a valid call after f is destroyed. +func (f *FSContext) SetWorkingDirectory(ctx context.Context, d *fs.Dirent) { if d == nil { panic("FSContext.SetWorkingDirectory called with nil dirent") } @@ -180,27 +182,31 @@ func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) { old := f.cwd f.cwd = d d.IncRef() - old.DecRef() + old.DecRef(ctx) } // SetWorkingDirectoryVFS2 sets the current working directory. // This will take an extra reference on the VirtualDentry. // -// This is not a valid call after destroy. -func (f *FSContext) SetWorkingDirectoryVFS2(d vfs.VirtualDentry) { +// This is not a valid call after f is destroyed. +func (f *FSContext) SetWorkingDirectoryVFS2(ctx context.Context, d vfs.VirtualDentry) { f.mu.Lock() defer f.mu.Unlock() + if !f.cwdVFS2.Ok() { + panic(fmt.Sprintf("FSContext.SetWorkingDirectoryVFS2(%v)) called after destroy", d)) + } + old := f.cwdVFS2 f.cwdVFS2 = d d.IncRef() - old.DecRef() + old.DecRef(ctx) } // RootDirectory returns the current filesystem root. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) RootDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() @@ -212,8 +218,8 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() @@ -225,8 +231,8 @@ func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { // SetRootDirectory sets the root directory. // This will take an extra reference on the Dirent. // -// This is not a valid call after free. -func (f *FSContext) SetRootDirectory(d *fs.Dirent) { +// This is not a valid call after f is destroyed. +func (f *FSContext) SetRootDirectory(ctx context.Context, d *fs.Dirent) { if d == nil { panic("FSContext.SetRootDirectory called with nil dirent") } @@ -241,13 +247,13 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) { old := f.root f.root = d d.IncRef() - old.DecRef() + old.DecRef(ctx) } // SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd. // -// This is not a valid call after free. -func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) { +// This is not a valid call after f is destroyed. +func (f *FSContext) SetRootDirectoryVFS2(ctx context.Context, vd vfs.VirtualDentry) { if !vd.Ok() { panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry") } @@ -263,7 +269,7 @@ func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) { vd.IncRef() f.rootVFS2 = vd f.mu.Unlock() - old.DecRef() + old.DecRef(ctx) } // Umask returns the current umask. diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index c5021f2db..daa2dae76 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -51,6 +51,7 @@ go_test( srcs = ["futex_test.go"], library = ":futex", deps = [ + "//pkg/context", "//pkg/sync", "//pkg/usermem", ], diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 732e66da4..e4dcc4d40 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -19,6 +19,7 @@ package futex import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -66,9 +67,9 @@ type Key struct { Offset uint64 } -func (k *Key) release() { +func (k *Key) release(t Target) { if k.MappingIdentity != nil { - k.MappingIdentity.DecRef() + k.MappingIdentity.DecRef(t) } k.Mappable = nil k.MappingIdentity = nil @@ -94,6 +95,8 @@ func (k *Key) matches(k2 *Key) bool { // Target abstracts memory accesses and keys. type Target interface { + context.Context + // SwapUint32 gives access to usermem.IO.SwapUint32. SwapUint32(addr usermem.Addr, new uint32) (uint32, error) @@ -296,7 +299,7 @@ func (b *bucket) wakeWaiterLocked(w *Waiter) { // bucket "to". // // Preconditions: b and to must be locked. -func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int { +func (b *bucket) requeueLocked(t Target, to *bucket, key, nkey *Key, n int) int { done := 0 for w := b.waiters.Front(); done < n && w != nil; { if !w.key.matches(key) { @@ -308,7 +311,7 @@ func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int { requeued := w w = w.Next() // Next iteration. b.waiters.Remove(requeued) - requeued.key.release() + requeued.key.release(t) requeued.key = nkey.clone() to.waiters.PushBack(requeued) requeued.bucket.Store(to) @@ -456,7 +459,7 @@ func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32 r := b.wakeLocked(&k, bitmask, n) b.mu.Unlock() - k.release() + k.release(t) return r, nil } @@ -465,12 +468,12 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch if err != nil { return 0, err } - defer k1.release() + defer k1.release(t) k2, err := getKey(t, naddr, private) if err != nil { return 0, err } - defer k2.release() + defer k2.release(t) b1, b2 := m.lockBuckets(&k1, &k2) defer b1.mu.Unlock() @@ -488,7 +491,7 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch done := b1.wakeLocked(&k1, ^uint32(0), nwake) // Requeue the number required. - b1.requeueLocked(b2, &k1, &k2, nreq) + b1.requeueLocked(t, b2, &k1, &k2, nreq) return done, nil } @@ -515,12 +518,12 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwak if err != nil { return 0, err } - defer k1.release() + defer k1.release(t) k2, err := getKey(t, addr2, private) if err != nil { return 0, err } - defer k2.release() + defer k2.release(t) b1, b2 := m.lockBuckets(&k1, &k2) defer b1.mu.Unlock() @@ -571,7 +574,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo // Perform our atomic check. if err := check(t, addr, val); err != nil { b.mu.Unlock() - w.key.release() + w.key.release(t) return err } @@ -585,7 +588,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo // WaitComplete must be called when a Waiter previously added by WaitPrepare is // no longer eligible to be woken. -func (m *Manager) WaitComplete(w *Waiter) { +func (m *Manager) WaitComplete(w *Waiter, t Target) { // Remove w from the bucket it's in. for { b := w.bucket.Load() @@ -617,7 +620,7 @@ func (m *Manager) WaitComplete(w *Waiter) { } // Release references held by the waiter. - w.key.release() + w.key.release(t) } // LockPI attempts to lock the futex following the Priority-inheritance futex @@ -648,13 +651,13 @@ func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, pri success, err := m.lockPILocked(w, t, addr, tid, b, try) if err != nil { - w.key.release() + w.key.release(t) b.mu.Unlock() return false, err } if success || try { // Release waiter if it's not going to be a wait. - w.key.release() + w.key.release(t) } b.mu.Unlock() return success, nil @@ -717,10 +720,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3 } } -// UnlockPI unlock the futex following the Priority-inheritance futex -// rules. The address provided must contain the caller's TID. If there are -// waiters, TID of the next waiter (FIFO) is set to the given address, and the -// waiter woken up. If there are no waiters, 0 is set to the address. +// UnlockPI unlocks the futex following the Priority-inheritance futex rules. +// The address provided must contain the caller's TID. If there are waiters, +// TID of the next waiter (FIFO) is set to the given address, and the waiter +// woken up. If there are no waiters, 0 is set to the address. func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error { k, err := getKey(t, addr, private) if err != nil { @@ -730,7 +733,7 @@ func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool err = m.unlockPILocked(t, addr, tid, b, &k) - k.release() + k.release(t) b.mu.Unlock() return err } diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index 7c5c7665b..d0128c548 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -22,6 +22,7 @@ import ( "testing" "unsafe" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -29,28 +30,33 @@ import ( // testData implements the Target interface, and allows us to // treat the address passed for futex operations as an index in // a byte slice for testing simplicity. -type testData []byte +type testData struct { + context.Context + data []byte +} const sizeofInt32 = 4 func newTestData(size uint) testData { - return make([]byte, size) + return testData{ + data: make([]byte, size), + } } func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { - val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new) + val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new) return val, nil } func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { - if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) { + if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) { return old, nil } - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil } func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) { - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil } func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) { @@ -83,7 +89,7 @@ func TestFutexWake(t *testing.T) { // Start waiting for wakeup. w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w) + defer m.WaitComplete(w, d) // Perform a wakeup. if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 { @@ -106,7 +112,7 @@ func TestFutexWakeBitmask(t *testing.T) { // Start waiting for wakeup. w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff) - defer m.WaitComplete(w) + defer m.WaitComplete(w, d) // Perform a wakeup using the wrong bitmask. if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 { @@ -141,7 +147,7 @@ func TestFutexWakeTwo(t *testing.T) { var ws [3]*Waiter for i := range ws { ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) + defer m.WaitComplete(ws[i], d) } // Perform two wakeups. @@ -174,9 +180,9 @@ func TestFutexWakeUnrelated(t *testing.T) { // Start two waiters waiting for wakeup on different addresses. w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform two wakeups on the second address. if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 { @@ -216,9 +222,9 @@ func TestWakeOpFirstNonEmpty(t *testing.T) { // Add two waiters on address 0. w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform 10 wakeups on address 0. if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 { @@ -244,9 +250,9 @@ func TestWakeOpSecondNonEmpty(t *testing.T) { // Add two waiters on address sizeofInt32. w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform 10 wakeups on address sizeofInt32 (contingent on // d.Op(0), which should succeed). @@ -273,9 +279,9 @@ func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) { // Add two waiters on address sizeofInt32. w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform 10 wakeups on address sizeofInt32 (contingent on // d.Op(1), which should fail). @@ -302,15 +308,15 @@ func TestWakeOpAllNonEmpty(t *testing.T) { // Add two waiters on address 0. w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Add two waiters on address sizeofInt32. w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) + defer m.WaitComplete(w3, d) w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) + defer m.WaitComplete(w4, d) // Perform 10 wakeups on address 0 (unconditionally), and 10 // wakeups on address sizeofInt32 (contingent on d.Op(0), which @@ -344,15 +350,15 @@ func TestWakeOpAllNonEmptyFailingOp(t *testing.T) { // Add two waiters on address 0. w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Add two waiters on address sizeofInt32. w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) + defer m.WaitComplete(w3, d) w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) + defer m.WaitComplete(w4, d) // Perform 10 wakeups on address 0 (unconditionally), and 10 // wakeups on address sizeofInt32 (contingent on d.Op(1), which @@ -388,7 +394,7 @@ func TestWakeOpSameAddress(t *testing.T) { var ws [4]*Waiter for i := range ws { ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) + defer m.WaitComplete(ws[i], d) } // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup @@ -422,7 +428,7 @@ func TestWakeOpSameAddressFailingOp(t *testing.T) { var ws [4]*Waiter for i := range ws { ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) + defer m.WaitComplete(ws[i], d) } // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup @@ -472,7 +478,7 @@ func (t *testMutex) Lock() { for { // Attempt to grab the lock. if atomic.CompareAndSwapUint32( - (*uint32)(unsafe.Pointer(&t.d[t.a])), + (*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked, testMutexLocked) { // Lock held. @@ -490,7 +496,7 @@ func (t *testMutex) Lock() { panic("WaitPrepare returned unexpected error: " + err.Error()) } <-w.C - t.m.WaitComplete(w) + t.m.WaitComplete(w, t.d) } } @@ -498,7 +504,7 @@ func (t *testMutex) Lock() { // This will notify any waiters via the futex manager. func (t *testMutex) Unlock() { // Unlock. - atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked) + atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked) // Notify all waiters. t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32) diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go new file mode 100644 index 000000000..060c056df --- /dev/null +++ b/pkg/sentry/kernel/kcov.go @@ -0,0 +1,335 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "fmt" + "io" + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/mm" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// kcovAreaSizeMax is the maximum number of uint64 entries allowed in the kcov +// area. On Linux, the maximum is INT_MAX / 8. +const kcovAreaSizeMax = 10 * 1024 * 1024 + +// Kcov provides kernel coverage data to userspace through a memory-mapped +// region, as kcov does in Linux. +// +// To give the illusion that the data is always up to date, we update the shared +// memory every time before we return to userspace. +type Kcov struct { + // mfp provides application memory. It is immutable after creation. + mfp pgalloc.MemoryFileProvider + + // mu protects all of the fields below. + mu sync.RWMutex + + // mode is the current kcov mode. + mode uint8 + + // size is the size of the mapping through which the kernel conveys coverage + // information to userspace. + size uint64 + + // owningTask is the task that currently owns coverage data on the system. The + // interface for kcov essentially requires that coverage is only going to a + // single task. Note that kcov should only generate coverage data for the + // owning task, but we currently generate global coverage. + owningTask *Task + + // count is a locally cached version of the first uint64 in the kcov data, + // which is the number of subsequent entries representing PCs. + // + // It is used with kcovInode.countBlock(), to copy in/out the first element of + // the actual data in an efficient manner, avoid boilerplate, and prevent + // accidental garbage escapes by the temporary counts. + count uint64 + + mappable *mm.SpecialMappable +} + +// NewKcov creates and returns a Kcov instance. +func (k *Kernel) NewKcov() *Kcov { + return &Kcov{ + mfp: k, + } +} + +var coveragePool = sync.Pool{ + New: func() interface{} { + return make([]byte, 0) + }, +} + +// TaskWork implements TaskWorker.TaskWork. +func (kcov *Kcov) TaskWork(t *Task) { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + if kcov.mode != linux.KCOV_MODE_TRACE_PC { + return + } + + rw := &kcovReadWriter{ + mf: kcov.mfp.MemoryFile(), + fr: kcov.mappable.FileRange(), + } + + // Read in the PC count. + if _, err := safemem.ReadFullToBlocks(rw, kcov.countBlock()); err != nil { + panic(fmt.Sprintf("Internal error reading count from kcov area: %v", err)) + } + + rw.off = 8 * (1 + kcov.count) + n := coverage.ConsumeCoverageData(&kcovIOWriter{rw}) + + // Update the pc count, based on the number of entries written. Note that if + // we reached the end of the kcov area, we may not have written everything in + // output. + kcov.count += uint64(n / 8) + rw.off = 0 + if _, err := safemem.WriteFullFromBlocks(rw, kcov.countBlock()); err != nil { + panic(fmt.Sprintf("Internal error writing count to kcov area: %v", err)) + } + + // Re-register for future work. + t.RegisterWork(kcov) +} + +// InitTrace performs the KCOV_INIT_TRACE ioctl. +func (kcov *Kcov) InitTrace(size uint64) error { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + if kcov.mode != linux.KCOV_MODE_DISABLED { + return syserror.EBUSY + } + + // To simplify all the logic around mapping, we require that the length of the + // shared region is a multiple of the system page size. + if (8*size)&(usermem.PageSize-1) != 0 { + return syserror.EINVAL + } + + // We need space for at least two uint64s to hold current position and a + // single PC. + if size < 2 || size > kcovAreaSizeMax { + return syserror.EINVAL + } + + kcov.size = size + kcov.mode = linux.KCOV_MODE_INIT + return nil +} + +// EnableTrace performs the KCOV_ENABLE_TRACE ioctl. +func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error { + t := TaskFromContext(ctx) + if t == nil { + panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine") + } + + kcov.mu.Lock() + defer kcov.mu.Unlock() + + // KCOV_ENABLE must be preceded by KCOV_INIT_TRACE and an mmap call. + if kcov.mode != linux.KCOV_MODE_INIT || kcov.mappable == nil { + return syserror.EINVAL + } + + switch traceKind { + case linux.KCOV_TRACE_PC: + kcov.mode = linux.KCOV_MODE_TRACE_PC + case linux.KCOV_TRACE_CMP: + // We do not support KCOV_MODE_TRACE_CMP. + return syserror.ENOTSUP + default: + return syserror.EINVAL + } + + if kcov.owningTask != nil && kcov.owningTask != t { + return syserror.EBUSY + } + + kcov.owningTask = t + t.SetKcov(kcov) + t.RegisterWork(kcov) + + // Clear existing coverage data; the task expects to read only coverage data + // from the time it is activated. + coverage.ClearCoverageData() + return nil +} + +// DisableTrace performs the KCOV_DISABLE_TRACE ioctl. +func (kcov *Kcov) DisableTrace(ctx context.Context) error { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + t := TaskFromContext(ctx) + if t == nil { + panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine") + } + + if t != kcov.owningTask { + return syserror.EINVAL + } + kcov.mode = linux.KCOV_MODE_INIT + kcov.owningTask = nil + kcov.mappable = nil + return nil +} + +// Clear resets the mode and clears the owning task and memory mapping for kcov. +// It is called when the fd corresponding to kcov is closed. Note that the mode +// needs to be set so that the next call to kcov.TaskWork() will exit early. +func (kcov *Kcov) Clear() { + kcov.mu.Lock() + kcov.clearLocked() + kcov.mu.Unlock() +} + +func (kcov *Kcov) clearLocked() { + kcov.mode = linux.KCOV_MODE_INIT + kcov.owningTask = nil + kcov.mappable = nil +} + +// OnTaskExit is called when the owning task exits. It is similar to +// kcov.Clear(), except the memory mapping is not cleared, so that the same +// mapping can be used in the future if kcov is enabled again by another task. +func (kcov *Kcov) OnTaskExit() { + kcov.mu.Lock() + kcov.mode = linux.KCOV_MODE_INIT + kcov.owningTask = nil + kcov.mu.Unlock() +} + +// ConfigureMMap is called by the vfs.FileDescription for this kcov instance to +// implement vfs.FileDescription.ConfigureMMap. +func (kcov *Kcov) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + if kcov.mode != linux.KCOV_MODE_INIT { + return syserror.EINVAL + } + + if kcov.mappable == nil { + // Set up the kcov area. + fr, err := kcov.mfp.MemoryFile().Allocate(kcov.size*8, usage.Anonymous) + if err != nil { + return err + } + + // Get the thread id for the mmap name. + t := TaskFromContext(ctx) + if t == nil { + panic("ThreadFromContext returned nil") + } + // For convenience, a special mappable is used here. Note that these mappings + // will look different under /proc/[pid]/maps than they do on Linux. + kcov.mappable = mm.NewSpecialMappable(fmt.Sprintf("[kcov:%d]", t.ThreadID()), kcov.mfp, fr) + } + opts.Mappable = kcov.mappable + opts.MappingIdentity = kcov.mappable + return nil +} + +// kcovReadWriter implements safemem.Reader and safemem.Writer. +type kcovReadWriter struct { + off uint64 + mf *pgalloc.MemoryFile + fr memmap.FileRange +} + +// ReadToBlocks implements safemem.Reader.ReadToBlocks. +func (rw *kcovReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + if dsts.IsEmpty() { + return 0, nil + } + + // Limit the read to the kcov range and check for overflow. + if rw.fr.Length() <= rw.off { + return 0, io.EOF + } + start := rw.fr.Start + rw.off + end := rw.fr.Start + rw.fr.Length() + if rend := start + dsts.NumBytes(); rend < end { + end = rend + } + + // Get internal mappings. + bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Read) + if err != nil { + return 0, err + } + + // Copy from internal mappings. + n, err := safemem.CopySeq(dsts, bs) + rw.off += n + return n, err +} + +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +func (rw *kcovReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + if srcs.IsEmpty() { + return 0, nil + } + + // Limit the write to the kcov area and check for overflow. + if rw.fr.Length() <= rw.off { + return 0, io.EOF + } + start := rw.fr.Start + rw.off + end := rw.fr.Start + rw.fr.Length() + if wend := start + srcs.NumBytes(); wend < end { + end = wend + } + + // Get internal mapping. + bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Write) + if err != nil { + return 0, err + } + + // Copy to internal mapping. + n, err := safemem.CopySeq(bs, srcs) + rw.off += n + return n, err +} + +// kcovIOWriter implements io.Writer as a basic wrapper over kcovReadWriter. +type kcovIOWriter struct { + rw *kcovReadWriter +} + +// Write implements io.Writer.Write. +func (w *kcovIOWriter) Write(p []byte) (int, error) { + bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p)) + n, err := safemem.WriteFullFromBlocks(w.rw, bs) + return int(n), err +} diff --git a/pkg/sentry/kernel/kcov_unsafe.go b/pkg/sentry/kernel/kcov_unsafe.go new file mode 100644 index 000000000..6f64022eb --- /dev/null +++ b/pkg/sentry/kernel/kcov_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/safemem" +) + +// countBlock provides a safemem.BlockSeq for k.count. +// +// Like k.count, the block returned is protected by k.mu. +func (k *Kcov) countBlock() safemem.BlockSeq { + return safemem.BlockSeqOf(safemem.BlockFromSafePointer(unsafe.Pointer(&k.count), int(unsafe.Sizeof(k.count)))) +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 2177b785a..d6c21adb7 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -81,6 +81,10 @@ import ( // easy access everywhere. To be removed once VFS2 becomes the default. var VFS2Enabled = false +// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow +// easy access everywhere. To be removed once FUSE is completed. +var FUSEEnabled = false + // Kernel represents an emulated Linux kernel. It must be initialized by calling // Init() or LoadFrom(). // @@ -216,13 +220,18 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // sockets is the list of all network sockets the system. Protected by - // extMu. + // sockets is the list of all network sockets in the system. + // Protected by extMu. + // TODO(gvisor.dev/issue/1624): Only used by VFS1. sockets socketList - // nextSocketEntry is the next entry number to use in sockets. Protected + // socketsVFS2 records all network sockets in the system. Protected by + // extMu. + socketsVFS2 map[*vfs.FileDescription]*SocketRecord + + // nextSocketRecord is the next entry number to use in sockets. Protected // by extMu. - nextSocketEntry uint64 + nextSocketRecord uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -244,7 +253,7 @@ type Kernel struct { // SpecialOpts contains special kernel options. SpecialOpts - // VFS keeps the filesystem state used across the kernel. + // vfs keeps the filesystem state used across the kernel. vfs vfs.VirtualFilesystem // hostMount is the Mount used for file descriptors that were imported @@ -372,7 +381,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.netlinkPorts = port.New() if VFS2Enabled { - if err := k.vfs.Init(); err != nil { + ctx := k.SupervisorContext() + if err := k.vfs.Init(ctx); err != nil { return fmt.Errorf("failed to initialize VFS: %v", err) } @@ -380,19 +390,19 @@ func (k *Kernel) Init(args InitKernelArgs) error { if err != nil { return fmt.Errorf("failed to create pipefs filesystem: %v", err) } - defer pipeFilesystem.DecRef() + defer pipeFilesystem.DecRef(ctx) pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to create pipefs mount: %v", err) } k.pipeMount = pipeMount - tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(k.SupervisorContext(), &k.vfs, auth.NewRootCredentials(k.rootUserNamespace)) + tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(ctx, &k.vfs, auth.NewRootCredentials(k.rootUserNamespace)) if err != nil { return fmt.Errorf("failed to create tmpfs filesystem: %v", err) } - defer tmpfsFilesystem.DecRef() - defer tmpfsRoot.DecRef() + defer tmpfsFilesystem.DecRef(ctx) + defer tmpfsRoot.DecRef(ctx) shmMount, err := k.vfs.NewDisconnectedMount(tmpfsFilesystem, tmpfsRoot, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to create tmpfs mount: %v", err) @@ -403,12 +413,14 @@ func (k *Kernel) Init(args InitKernelArgs) error { if err != nil { return fmt.Errorf("failed to create sockfs filesystem: %v", err) } - defer socketFilesystem.DecRef() + defer socketFilesystem.DecRef(ctx) socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to create sockfs mount: %v", err) } k.socketMount = socketMount + + k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) } return nil @@ -426,8 +438,8 @@ func (k *Kernel) SaveTo(w wire.Writer) error { defer k.extMu.Unlock() // Stop time. - k.pauseTimeLocked() - defer k.resumeTimeLocked() + k.pauseTimeLocked(ctx) + defer k.resumeTimeLocked(ctx) // Evict all evictable MemoryFile allocations. k.mf.StartEvictions() @@ -443,12 +455,12 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Remove all epoll waiter objects from underlying wait queues. // NOTE: for programs to resume execution in future snapshot scenarios, // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters() + k.tasks.unregisterEpollWaiters(ctx) // Clear the dirent cache before saving because Dirents must be Loaded in a // particular order (parents before children), and Loading dirents from a cache // breaks that order. - if err := k.flushMountSourceRefs(); err != nil { + if err := k.flushMountSourceRefs(ctx); err != nil { return err } @@ -501,7 +513,11 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. -func (k *Kernel) flushMountSourceRefs() error { +func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { + if VFS2Enabled { + return nil // Not relevant. + } + // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -517,7 +533,7 @@ func (k *Kernel) flushMountSourceRefs() error { // There may be some open FDs whose filesystems have been unmounted. We // must flush those as well. - return k.tasks.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { file.Dirent.Inode.MountSource.FlushDirentRefs() return nil }) @@ -527,12 +543,7 @@ func (k *Kernel) flushMountSourceRefs() error { // each task. // // Precondition: Must be called with the kernel paused. -func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return nil - } - +func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.FileDescription) error) (err error) { ts.mu.RLock() defer ts.mu.RUnlock() for t := range ts.Root.tids { @@ -540,7 +551,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) if t.fdTable == nil { continue } - t.fdTable.forEach(func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) { if lastErr := f(file, fileVFS2); lastErr != nil && err == nil { err = lastErr } @@ -551,7 +562,11 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error { + if VFS2Enabled { + return nil + } + + return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -598,7 +613,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { return nil } -func (ts *TaskSet) unregisterEpollWaiters() { +func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { // TODO(gvisor.dev/issue/1663): Add save support for VFS2. if VFS2Enabled { return @@ -619,7 +634,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { if _, ok := processed[t.fdTable]; ok { continue } - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { if e, ok := file.FileOperations.(*epoll.EventPoll); ok { e.UnregisterEpollWaiters() } @@ -883,20 +898,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, opener fsbridge.Lookup fsContext *FSContext mntns *fs.MountNamespace + mntnsVFS2 *vfs.MountNamespace ) if VFS2Enabled { - mntnsVFS2 := args.MountNamespaceVFS2 + mntnsVFS2 = args.MountNamespaceVFS2 if mntnsVFS2 == nil { // MountNamespaceVFS2 adds a reference to the namespace, which is // transferred to the new process. mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2() } // Get the root directory from the MountNamespace. - root := args.MountNamespaceVFS2.Root() + root := mntnsVFS2.Root() // The call to newFSContext below will take a reference on root, so we // don't need to hold this one. - defer root.DecRef() + defer root.DecRef(ctx) // Grab the working directory. wd := root // Default. @@ -914,7 +930,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if err != nil { return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err) } - defer wd.DecRef() + defer wd.DecRef(ctx) } opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd) fsContext = NewFSContextVFS2(root, wd, args.Umask) @@ -929,7 +945,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, root := mntns.Root() // The call to newFSContext below will take a reference on root, so we // don't need to hold this one. - defer root.DecRef() + defer root.DecRef(ctx) // Grab the working directory. remainingTraversals := args.MaxSymlinkTraversals @@ -940,7 +956,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if err != nil { return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err) } - defer wd.DecRef() + defer wd.DecRef(ctx) } opener = fsbridge.NewFSLookup(mntns, root, wd) fsContext = newFSContext(root, wd, args.Umask) @@ -1003,7 +1019,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, UTSNamespace: args.UTSNamespace, IPCNamespace: args.IPCNamespace, AbstractSocketNamespace: args.AbstractSocketNamespace, - MountNamespaceVFS2: args.MountNamespaceVFS2, + MountNamespaceVFS2: mntnsVFS2, ContainerID: args.ContainerID, } t, err := k.tasks.NewTask(config) @@ -1050,7 +1066,7 @@ func (k *Kernel) Start() error { // If k was created by LoadKernelFrom, timers were stopped during // Kernel.SaveTo and need to be resumed. If k was created by NewKernel, // this is a no-op. - k.resumeTimeLocked() + k.resumeTimeLocked(k.SupervisorContext()) // Start task goroutines. k.tasks.mu.RLock() defer k.tasks.mu.RUnlock() @@ -1062,9 +1078,10 @@ func (k *Kernel) Start() error { // pauseTimeLocked pauses all Timers and Timekeeper updates. // -// Preconditions: Any task goroutines running in k must be stopped. k.extMu -// must be locked. -func (k *Kernel) pauseTimeLocked() { +// Preconditions: +// * Any task goroutines running in k must be stopped. +// * k.extMu must be locked. +func (k *Kernel) pauseTimeLocked(ctx context.Context) { // k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before // Kernel.Start(). if k.cpuClockTicker != nil { @@ -1086,7 +1103,7 @@ func (k *Kernel) pauseTimeLocked() { // This means we'll iterate FDTables shared by multiple tasks repeatedly, // but ktime.Timer.Pause is idempotent so this is harmless. if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { if VFS2Enabled { if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok { tfd.PauseTimer() @@ -1106,9 +1123,10 @@ func (k *Kernel) pauseTimeLocked() { // pauseTimeLocked has not been previously called, resumeTimeLocked has no // effect. // -// Preconditions: Any task goroutines running in k must be stopped. k.extMu -// must be locked. -func (k *Kernel) resumeTimeLocked() { +// Preconditions: +// * Any task goroutines running in k must be stopped. +// * k.extMu must be locked. +func (k *Kernel) resumeTimeLocked(ctx context.Context) { if k.cpuClockTicker != nil { k.cpuClockTicker.Resume() } @@ -1122,7 +1140,7 @@ func (k *Kernel) resumeTimeLocked() { } } if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { if VFS2Enabled { if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok { tfd.ResumeTimer() @@ -1258,6 +1276,13 @@ func (k *Kernel) Pause() { k.tasks.aioGoroutines.Wait() } +// ReceiveTaskStates receives full states for all tasks. +func (k *Kernel) ReceiveTaskStates() { + k.extMu.Lock() + k.tasks.PullFullState() + k.extMu.Unlock() +} + // Unpause ends the effect of a previous call to Pause. If Unpause is called // without a matching preceding call to Pause, Unpause may panic. func (k *Kernel) Unpause() { @@ -1465,6 +1490,11 @@ func (k *Kernel) NowMonotonic() int64 { return now } +// AfterFunc implements tcpip.Clock.AfterFunc. +func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(k.realtimeClock, d, f) +} + // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { @@ -1489,20 +1519,27 @@ func (k *Kernel) SupervisorContext() context.Context { } } -// SocketEntry represents a socket recorded in Kernel.sockets. It implements +// SocketRecord represents a socket recorded in Kernel.socketsVFS2. +// +// +stateify savable +type SocketRecord struct { + k *Kernel + Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1. + SockVFS2 *vfs.FileDescription // Only used by VFS2. + ID uint64 // Socket table entry number. +} + +// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type SocketEntry struct { +type SocketRecordVFS1 struct { socketEntry - k *Kernel - Sock *refs.WeakRef - SockVFS2 *vfs.FileDescription - ID uint64 // Socket table entry number. + SocketRecord } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *SocketEntry) WeakRefGone() { +func (s *SocketRecordVFS1) WeakRefGone(context.Context) { s.k.extMu.Lock() s.k.sockets.Remove(s) s.k.extMu.Unlock() @@ -1513,9 +1550,14 @@ func (s *SocketEntry) WeakRefGone() { // Precondition: Caller must hold a reference to sock. func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{k: k, ID: id} + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecordVFS1{ + SocketRecord: SocketRecord{ + k: k, + ID: id, + }, + } s.Sock = refs.NewWeakRef(sock, s) k.sockets.PushBack(s) k.extMu.Unlock() @@ -1527,29 +1569,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) { // Precondition: Caller must hold a reference to sock. // // Note that the socket table will not hold a reference on the -// vfs.FileDescription, because we do not support weak refs on VFS2 files. +// vfs.FileDescription. func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{ + if _, ok := k.socketsVFS2[sock]; ok { + panic(fmt.Sprintf("Socket %p added twice", sock)) + } + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecord{ k: k, ID: id, SockVFS2: sock, } - k.sockets.PushBack(s) + k.socketsVFS2[sock] = s + k.extMu.Unlock() +} + +// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table. +func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) { + k.extMu.Lock() + delete(k.socketsVFS2, sock) k.extMu.Unlock() } // ListSockets returns a snapshot of all sockets. // -// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef() +// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef() // to get a reference on a socket in the table. -func (k *Kernel) ListSockets() []*SocketEntry { +func (k *Kernel) ListSockets() []*SocketRecord { k.extMu.Lock() - var socks []*SocketEntry - for s := k.sockets.Front(); s != nil; s = s.Next() { - socks = append(socks, s) + var socks []*SocketRecord + if VFS2Enabled { + for _, s := range k.socketsVFS2 { + socks = append(socks, s) + } + } else { + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, &s.SocketRecord) + } } k.extMu.Unlock() return socks @@ -1591,7 +1649,7 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { return vfs.VirtualDentry{} } mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2() - defer mntns.DecRef() + defer mntns.DecRef(ctx) // Root() takes a reference on the root dirent for us. return mntns.Root() case vfs.CtxMountNamespace: diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 449643118..99134e634 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/amutex", "//pkg/buffer", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 4b688c627..6497dc4ba 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -93,7 +93,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { if !waitFor(&i.mu, &i.wWakeup, ctx) { - r.DecRef() + r.DecRef(ctx) return nil, syserror.ErrInterrupted } } @@ -111,12 +111,12 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi // On a nonblocking, write-only open, the open fails with ENXIO if the // read side isn't open yet. if flags.NonBlocking { - w.DecRef() + w.DecRef(ctx) return nil, syserror.ENXIO } if !waitFor(&i.mu, &i.rWakeup, ctx) { - w.DecRef() + w.DecRef(ctx) return nil, syserror.ErrInterrupted } } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index ab75a87ff..ce0db5583 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -167,7 +167,7 @@ func TestClosedReaderBlocksWriteOpen(t *testing.T) { f := NewInodeOperations(ctx, perms, newNamedPipe(t)) rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil) - rFile.DecRef() + rFile.DecRef(ctx) wDone := make(chan struct{}) // This open for write should block because the reader is now gone. diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 79645d7d2..67beb0ad6 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -17,6 +17,7 @@ package pipe import ( "fmt" + "io" "sync/atomic" "syscall" @@ -152,7 +153,7 @@ func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs. d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino)) // The p.Open calls below will each take a reference on the Dirent. We // must drop the one we already have. - defer d.DecRef() + defer d.DecRef(ctx) return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true}) } @@ -200,22 +201,22 @@ type readOps struct { // // Precondition: this pipe must have readers. func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { - // Don't block for a zero-length read even if the pipe is empty. - if ops.left() == 0 { - return 0, nil - } - p.mu.Lock() defer p.mu.Unlock() return p.readLocked(ctx, ops) } func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { + // Don't block for a zero-length read even if the pipe is empty. + if ops.left() == 0 { + return 0, nil + } + // Is the pipe empty? if p.view.Size() == 0 { if !p.HasWriters() { // There are no writers, return EOF. - return 0, nil + return 0, io.EOF } return 0, syserror.ErrWouldBlock } @@ -388,6 +389,10 @@ func (p *Pipe) rwReadiness() waiter.EventMask { func (p *Pipe) queued() int64 { p.mu.Lock() defer p.mu.Unlock() + return p.queuedLocked() +} + +func (p *Pipe) queuedLocked() int64 { return p.view.Size() } diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index bda739dbe..fe97e9800 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -27,8 +27,8 @@ import ( func TestPipeRW(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) msg := []byte("here's some bytes") wantN := int64(len(msg)) @@ -47,8 +47,8 @@ func TestPipeRW(t *testing.T) { func TestPipeReadBlock(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1))) if n != 0 || err != syserror.ErrWouldBlock { @@ -62,8 +62,8 @@ func TestPipeWriteBlock(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) msg := make([]byte, capacity+1) n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) @@ -77,8 +77,8 @@ func TestPipeWriteUntilEnd(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) msg := []byte("here's some bytes") diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index aacf28da2..f665920cb 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -33,7 +34,7 @@ import ( // the old fs architecture. // Release cleans up the pipe's state. -func (p *Pipe) Release() { +func (p *Pipe) Release(context.Context) { p.rClose() p.wClose() @@ -145,9 +146,14 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume v = math.MaxInt32 // Silently truncate. } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + iocc := primitive.IOCopyContext{ + IO: io, + Ctx: ctx, + Opts: usermem.IOOpts{ + AddressSpaceActive: true, + }, + } + _, err := primitive.CopyInt32Out(&iocc, args[2].Pointer(), int32(v)) return 0, err default: return 0, syscall.ENOTTY diff --git a/pkg/sentry/kernel/pipe/reader.go b/pkg/sentry/kernel/pipe/reader.go index 7724b4452..ac18785c0 100644 --- a/pkg/sentry/kernel/pipe/reader.go +++ b/pkg/sentry/kernel/pipe/reader.go @@ -15,6 +15,7 @@ package pipe import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/waiter" ) @@ -29,7 +30,7 @@ type Reader struct { // Release implements fs.FileOperations.Release. // // This overrides ReaderWriter.Release. -func (r *Reader) Release() { +func (r *Reader) Release(context.Context) { r.Pipe.rClose() // Wake up writers. diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 45d4c5fc1..f61039f5b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -67,6 +67,11 @@ func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlag return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (*VFSPipe) Allocate(context.Context, uint64, uint64, uint64) error { + return syserror.ESPIPE +} + // Open opens the pipe represented by vp. func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() @@ -101,7 +106,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s // If this pipe is being opened as blocking and there's no // writer, we have to wait for a writer to open the other end. if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EINTR } @@ -112,12 +117,12 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s // Non-blocking, write-only opens fail with ENXIO when the read // side isn't open yet. if statusFlags&linux.O_NONBLOCK != 0 { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.ENXIO } // Wait for a reader to open the other end. if !waitFor(&vp.mu, &vp.rWakeup, ctx) { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EINTR } } @@ -169,7 +174,7 @@ type VFSPipeFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *VFSPipeFD) Release() { +func (fd *VFSPipeFD) Release(context.Context) { var event waiter.EventMask if fd.vfsfd.IsReadable() { fd.pipe.rClose() @@ -244,19 +249,57 @@ func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { return fd.pipe.SetFifoSize(size) } -// IOSequence returns a useremm.IOSequence that reads up to count bytes from, -// or writes up to count bytes to, fd. -func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence { - return usermem.IOSequence{ +// SpliceToNonPipe performs a splice operation from fd to a non-pipe file. +func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescription, off, count int64) (int64, error) { + fd.pipe.mu.Lock() + defer fd.pipe.mu.Unlock() + + // Cap the sequence at number of bytes actually available. + v := fd.pipe.queuedLocked() + if v < count { + count = v + } + src := usermem.IOSequence{ IO: fd, Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), } + + var ( + n int64 + err error + ) + if off == -1 { + n, err = out.Write(ctx, src, vfs.WriteOptions{}) + } else { + n, err = out.PWrite(ctx, src, off, vfs.WriteOptions{}) + } + if n > 0 { + fd.pipe.view.TrimFront(n) + } + return n, err +} + +// SpliceFromNonPipe performs a splice operation from a non-pipe file to fd. +func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) { + fd.pipe.mu.Lock() + defer fd.pipe.mu.Unlock() + + dst := usermem.IOSequence{ + IO: fd, + Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + } + + if off == -1 { + return in.Read(ctx, dst, vfs.ReadOptions{}) + } + return in.PRead(ctx, dst, off, vfs.ReadOptions{}) } -// CopyIn implements usermem.IO.CopyIn. +// CopyIn implements usermem.IO.CopyIn. Note that it is the caller's +// responsibility to trim fd.pipe.view after the read is completed. func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { origCount := int64(len(dst)) - n, err := fd.pipe.read(ctx, readOps{ + n, err := fd.pipe.readLocked(ctx, readOps{ left: func() int64 { return int64(len(dst)) }, @@ -265,7 +308,6 @@ func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, }, read: func(view *buffer.View) (int64, error) { n, err := view.ReadAt(dst, 0) - view.TrimFront(int64(n)) return int64(n), err }, }) @@ -281,7 +323,7 @@ func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, // CopyOut implements usermem.IO.CopyOut. func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { origCount := int64(len(src)) - n, err := fd.pipe.write(ctx, writeOps{ + n, err := fd.pipe.writeLocked(ctx, writeOps{ left: func() int64 { return int64(len(src)) }, @@ -305,7 +347,7 @@ func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, // ZeroOut implements usermem.IO.ZeroOut. func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { origCount := toZero - n, err := fd.pipe.write(ctx, writeOps{ + n, err := fd.pipe.writeLocked(ctx, writeOps{ left: func() int64 { return toZero }, @@ -326,14 +368,15 @@ func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int6 return n, err } -// CopyInTo implements usermem.IO.CopyInTo. +// CopyInTo implements usermem.IO.CopyInTo. Note that it is the caller's +// responsibility to trim fd.pipe.view after the read is completed. func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { count := ars.NumBytes() if count == 0 { return 0, nil } origCount := count - n, err := fd.pipe.read(ctx, readOps{ + n, err := fd.pipe.readLocked(ctx, readOps{ left: func() int64 { return count }, @@ -342,7 +385,6 @@ func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst }, read: func(view *buffer.View) (int64, error) { n, err := view.ReadToSafememWriter(dst, uint64(count)) - view.TrimFront(int64(n)) return int64(n), err }, }) @@ -362,7 +404,7 @@ func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, return 0, nil } origCount := count - n, err := fd.pipe.write(ctx, writeOps{ + n, err := fd.pipe.writeLocked(ctx, writeOps{ left: func() int64 { return count }, diff --git a/pkg/sentry/kernel/pipe/writer.go b/pkg/sentry/kernel/pipe/writer.go index 5bc6aa931..ef4b70ca3 100644 --- a/pkg/sentry/kernel/pipe/writer.go +++ b/pkg/sentry/kernel/pipe/writer.go @@ -15,6 +15,7 @@ package pipe import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/waiter" ) @@ -29,7 +30,7 @@ type Writer struct { // Release implements fs.FileOperations.Release. // // This overrides ReaderWriter.Release. -func (w *Writer) Release() { +func (w *Writer) Release(context.Context) { w.Pipe.wClose() // Wake up readers. diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index e23e796ef..1145faf13 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" @@ -224,8 +225,9 @@ func (s *ptraceStop) Killable() bool { // beginPtraceStopLocked does not signal t's tracer or wake it if it is // waiting. // -// Preconditions: The TaskSet mutex must be locked. The caller must be running -// on the task goroutine. +// Preconditions: +// * The TaskSet mutex must be locked. +// * The caller must be running on the task goroutine. func (t *Task) beginPtraceStopLocked() bool { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() @@ -270,8 +272,9 @@ func (t *Task) ptraceTrapLocked(code int32) { // ptraceStop, temporarily preventing it from being removed by a concurrent // Task.Kill, and returns true. Otherwise it returns false. // -// Preconditions: The TaskSet mutex must be locked. The caller must be running -// on the task goroutine of t's tracer. +// Preconditions: +// * The TaskSet mutex must be locked. +// * The caller must be running on the task goroutine of t's tracer. func (t *Task) ptraceFreeze() bool { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() @@ -301,8 +304,9 @@ func (t *Task) ptraceUnfreeze() { t.ptraceUnfreezeLocked() } -// Preconditions: t must be in a frozen ptraceStop. t's signal mutex must be -// locked. +// Preconditions: +// * t must be in a frozen ptraceStop. +// * t's signal mutex must be locked. func (t *Task) ptraceUnfreezeLocked() { // Do this even if the task has been killed to ensure a panic if t.stop is // nil or not a ptraceStop. @@ -497,8 +501,9 @@ func (t *Task) forgetTracerLocked() { // ptraceSignalLocked is called after signal dequeueing to check if t should // enter ptrace signal-delivery-stop. // -// Preconditions: The signal mutex must be locked. The caller must be running -// on the task goroutine. +// Preconditions: +// * The signal mutex must be locked. +// * The caller must be running on the task goroutine. func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool { if linux.Signal(info.Signo) == linux.SIGKILL { return false @@ -828,8 +833,9 @@ func (t *Task) ptraceInterrupt(target *Task) error { return nil } -// Preconditions: The TaskSet mutex must be locked for writing. t must have a -// tracer. +// Preconditions: +// * The TaskSet mutex must be locked for writing. +// * t must have a tracer. func (t *Task) ptraceSetOptionsLocked(opts uintptr) error { const valid = uintptr(linux.PTRACE_O_EXITKILL | linux.PTRACE_O_TRACESYSGOOD | @@ -994,18 +1000,15 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{ - IgnorePermissions: true, - }); err != nil { + if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } - _, err := t.CopyOut(data, word) + _, err := word.CopyOut(t, data) return err case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: - _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{ - IgnorePermissions: true, - }) + word := t.Arch().Native(uintptr(data)) + _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: @@ -1018,6 +1021,9 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if err != nil { return err } + + t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) + ar := ars.Head() n, err := target.Arch().PtraceGetRegSet(uintptr(addr), &usermem.IOReadWriter{ Ctx: t, @@ -1044,10 +1050,14 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if err != nil { return err } + + mm := t.MemoryManager() + t.p.PullFullState(mm.AddressSpace(), t.Arch()) + ar := ars.Head() n, err := target.Arch().PtraceSetRegSet(uintptr(addr), &usermem.IOReadWriter{ Ctx: t, - IO: t.MemoryManager(), + IO: mm, Addr: ar.Start, Opts: usermem.IOOpts{ AddressSpaceActive: true, @@ -1056,6 +1066,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if err != nil { return err } + t.p.FullStateChanged() ar.End -= usermem.Addr(n) return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar)) @@ -1065,12 +1076,12 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if target.ptraceSiginfo == nil { return syserror.EINVAL } - _, err := t.CopyOut(data, target.ptraceSiginfo) + _, err := target.ptraceSiginfo.CopyOut(t, data) return err case linux.PTRACE_SETSIGINFO: var info arch.SignalInfo - if _, err := t.CopyIn(data, &info); err != nil { + if _, err := info.CopyIn(t, data); err != nil { return err } t.tg.pidns.owner.mu.RLock() @@ -1085,7 +1096,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if addr != linux.SignalSetSize { return syserror.EINVAL } - _, err := t.CopyOut(data, target.SignalMask()) + mask := target.SignalMask() + _, err := mask.CopyOut(t, data) return err case linux.PTRACE_SETSIGMASK: @@ -1093,7 +1105,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { return syserror.EINVAL } var mask linux.SignalSet - if _, err := t.CopyIn(data, &mask); err != nil { + if _, err := mask.CopyIn(t, data); err != nil { return err } // The target's task goroutine is stopped, so this is safe: @@ -1108,7 +1120,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_GETEVENTMSG: t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() - _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg) + _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg) return err // PEEKSIGINFO is unimplemented but seems to have no users anywhere. diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index cef1276ec..609ad3941 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -30,7 +30,7 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro if err != nil { return err } - _, err = t.CopyOut(data, n) + _, err = n.CopyOut(t, data) return err case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 18416643b..2a9023fdf 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -173,8 +173,10 @@ func (t *Task) OldRSeqCPUAddr() usermem.Addr { // SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with // t's CPU number. // -// Preconditions: t.RSeqAvailable() == true. The caller must be running on the -// task goroutine. t's AddressSpace must be active. +// Preconditions: +// * t.RSeqAvailable() == true. +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error { t.oldRSeqCPUAddr = addr @@ -189,8 +191,9 @@ func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error { return nil } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqUpdateCPU() error { if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 { t.rseqCPU = -1 @@ -209,8 +212,9 @@ func (t *Task) rseqUpdateCPU() error { return oerr } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) oldRSeqCopyOutCPU() error { if t.oldRSeqCPUAddr == 0 { return nil @@ -222,8 +226,9 @@ func (t *Task) oldRSeqCopyOutCPU() error { return err } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqCopyOutCPU() error { if t.rseqAddr == 0 { return nil @@ -240,8 +245,9 @@ func (t *Task) rseqCopyOutCPU() error { return err } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqClearCPU() error { buf := t.CopyScratchBuffer(8) // CPUIDStart and CPUID are the first two fields in linux.RSeq. @@ -269,8 +275,9 @@ func (t *Task) rseqClearCPU() error { // // See kernel/rseq.c:rseq_ip_fixup for reference. // -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqAddrInterrupt() { if t.rseqAddr == 0 { return diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index c38c5a40c..387edfa91 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -18,7 +18,6 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" @@ -27,25 +26,18 @@ import ( const maxSyscallFilterInstructions = 1 << 15 -// seccompData is equivalent to struct seccomp_data, which contains the data -// passed to seccomp-bpf filters. -type seccompData struct { - // nr is the system call number. - nr int32 - - // arch is an AUDIT_ARCH_* value indicating the system call convention. - arch uint32 - - // instructionPointer is the value of the instruction pointer at the time - // of the system call. - instructionPointer uint64 - - // args contains the first 6 system call arguments. - args [6]uint64 -} - -func (d *seccompData) asBPFInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder} +// dataAsBPFInput returns a serialized BPF program, only valid on the current task +// goroutine. +// +// Note: this is called for every syscall, which is a very hot path. +func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input { + buf := t.CopyScratchBuffer(d.SizeBytes()) + d.MarshalUnsafe(buf) + return bpf.InputBytes{ + Data: buf, + // Go-marshal always uses the native byte order. + Order: usermem.ByteOrder, + } } func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo { @@ -112,20 +104,20 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u } func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 { - data := seccompData{ - nr: sysno, - arch: t.tc.st.AuditNumber, - instructionPointer: uint64(ip), + data := linux.SeccompData{ + Nr: sysno, + Arch: t.tc.st.AuditNumber, + InstructionPointer: uint64(ip), } // data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so // we can't do any slicing tricks or even use copy/append here. for i, arg := range args { - if i >= len(data.args) { + if i >= len(data.Args) { break } - data.args[i] = arg.Uint64() + data.Args[i] = arg.Uint64() } - input := data.asBPFInput() + input := dataAsBPFInput(t, &data) ret := uint32(linux.SECCOMP_RET_ALLOW) f := t.syscallFilters.Load() diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 0e19286de..df5c8421b 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -16,7 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" ) @@ -31,7 +30,7 @@ type ProcessGroupID ThreadID // // +stateify savable type Session struct { - refs refs.AtomicRefCount + SessionRefs // leader is the originator of the Session. // @@ -61,16 +60,11 @@ type Session struct { sessionEntry } -// incRef grabs a reference. -func (s *Session) incRef() { - s.refs.IncRef() -} - -// decRef drops a reference. +// DecRef drops a reference. // // Precondition: callers must hold TaskSet.mu for writing. -func (s *Session) decRef() { - s.refs.DecRefWithDestructor(func() { +func (s *Session) DecRef() { + s.SessionRefs.DecRef(func() { // Remove translations from the leader. for ns := s.leader.pidns; ns != nil; ns = ns.parent { id := ns.sids[s] @@ -87,7 +81,7 @@ func (s *Session) decRef() { // // +stateify savable type ProcessGroup struct { - refs refs.AtomicRefCount // not exported. + refs ProcessGroupRefs // originator is the originator of the group. // @@ -162,7 +156,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) { } alive := true - pg.refs.DecRefWithDestructor(func() { + pg.refs.DecRef(func() { alive = false // don't bother with handleOrphan. // Remove translations from the originator. @@ -174,7 +168,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) { // Remove the list of process groups. pg.session.processGroups.Remove(pg) - pg.session.decRef() + pg.session.DecRef() }) if alive { pg.handleOrphan() @@ -301,7 +295,7 @@ func (tg *ThreadGroup) createSession() error { id: SessionID(id), leader: tg, } - s.refs.EnableLeakCheck("kernel.Session") + s.EnableLeakCheck() // Create a new ProcessGroup, belonging to that Session. // This also has a single reference (assigned below). @@ -315,7 +309,7 @@ func (tg *ThreadGroup) createSession() error { session: s, ancestors: 0, } - pg.refs.EnableLeakCheck("kernel.ProcessGroup") + pg.refs.EnableLeakCheck() // Tie them and return the result. s.processGroups.PushBack(pg) @@ -395,13 +389,13 @@ func (tg *ThreadGroup) CreateProcessGroup() error { // // We manually adjust the ancestors if the parent is in the same // session. - tg.processGroup.session.incRef() + tg.processGroup.session.IncRef() pg := ProcessGroup{ id: ProcessGroupID(id), originator: tg, session: tg.processGroup.session, } - pg.refs.EnableLeakCheck("kernel.ProcessGroup") + pg.refs.EnableLeakCheck() if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session { pg.ancestors++ diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index bfd779837..b7e4b480d 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,12 +1,25 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "shm_refs", + out = "shm_refs.go", + package = "shm", + prefix = "Shm", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Shm", + }, +) + go_library( name = "shm", srcs = [ "device.go", "shm.go", + "shm_refs.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -20,7 +33,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index f66cfcc7f..00c03585e 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -39,13 +39,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -253,7 +251,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi creatorPID: pid, changeTime: ktime.NowFromContext(ctx), } - shm.EnableLeakCheck("kernel.Shm") + shm.EnableLeakCheck() // Find the next available ID. for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ { @@ -338,14 +336,14 @@ func (r *Registry) remove(s *Shm) { // // +stateify savable type Shm struct { - // AtomicRefCount tracks the number of references to this segment. + // ShmRefs tracks the number of references to this segment. // // A segment holds a reference to itself until it is marked for // destruction. // // In addition to direct users, the MemoryManager will hold references // via MappingIdentity. - refs.AtomicRefCount + ShmRefs mfp pgalloc.MemoryFileProvider @@ -370,7 +368,7 @@ type Shm struct { // fr is the offset into mfp.MemoryFile() that backs this contents of this // segment. Immutable. - fr platform.FileRange + fr memmap.FileRange // mu protects all fields below. mu sync.Mutex `state:"nosave"` @@ -429,11 +427,14 @@ func (s *Shm) InodeID() uint64 { return uint64(s.ID) } -// DecRef overrides refs.RefCount.DecRef with a destructor. +// DecRef drops a reference on s. // // Precondition: Caller must not hold s.mu. -func (s *Shm) DecRef() { - s.DecRefWithDestructor(s.destroy) +func (s *Shm) DecRef(ctx context.Context) { + s.ShmRefs.DecRef(func() { + s.mfp.MemoryFile().DecRef(s.fr) + s.registry.remove(s) + }) } // Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm @@ -643,16 +644,11 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error { return nil } -func (s *Shm) destroy() { - s.mfp.MemoryFile().DecRef(s.fr) - s.registry.remove(s) -} - // MarkDestroyed marks a segment for destruction. The segment is actually // destroyed once it has no references. MarkDestroyed may be called multiple // times, and is safe to call after a segment has already been destroyed. See // shmctl(IPC_RMID). -func (s *Shm) MarkDestroyed() { +func (s *Shm) MarkDestroyed(ctx context.Context) { s.registry.dissociateKey(s) s.mu.Lock() @@ -664,7 +660,7 @@ func (s *Shm) MarkDestroyed() { // // N.B. This cannot be the final DecRef, as the caller also // holds a reference. - s.DecRef() + s.DecRef(ctx) return } } diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 3eb78e91b..76d472292 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 8243bb93e..78f718cfe 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -17,7 +17,6 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" @@ -76,7 +75,7 @@ func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { } // Release implements fs.FileOperations.Release. -func (s *SignalOperations) Release() {} +func (s *SignalOperations) Release(context.Context) {} // Mask returns the signal mask. func (s *SignalOperations) Mask() linux.SignalSet { @@ -103,8 +102,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // Copy out the signal info using the specified format. - var buf [128]byte - binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + infoNative := linux.SignalfdSiginfo{ Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, @@ -113,9 +111,13 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), - }) - n, err := dst.CopyOut(ctx, buf[:]) - return int64(n), err + } + n, err := infoNative.WriteTo(dst.Writer(ctx)) + if err == usermem.ErrEndOfIOSequence { + // Partial copy-out ok. + err = nil + } + return n, err } // Readiness implements waiter.Waitable.Readiness. diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 413111faf..332bdb8e8 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -348,6 +348,16 @@ func (s *SyscallTable) LookupName(sysno uintptr) string { return fmt.Sprintf("sys_%d", sysno) // Unlikely. } +// LookupNo looks up a syscall number by name. +func (s *SyscallTable) LookupNo(name string) (uintptr, error) { + for i, syscall := range s.Table { + if syscall.Name == name { + return uintptr(i), nil + } + } + return 0, fmt.Errorf("syscall %q not found", name) +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 4607cde2f..a83ce219c 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -98,6 +98,15 @@ func (s *syslog) Log() []byte { s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...) } + if VFS2Enabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...) + if FUSEEnabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...) + } + } + time += rand.Float64() / 2 s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index f48247c94..f796e0fa3 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -68,6 +68,21 @@ type Task struct { // runState is exclusive to the task goroutine. runState taskRunState + // taskWorkCount represents the current size of the task work queue. It is + // used to avoid acquiring taskWorkMu when the queue is empty. + // + // Must accessed with atomic memory operations. + taskWorkCount int32 + + // taskWorkMu protects taskWork. + taskWorkMu sync.Mutex `state:"nosave"` + + // taskWork is a queue of work to be executed before resuming user execution. + // It is similar to the task_work mechanism in Linux. + // + // taskWork is exclusive to the task goroutine. + taskWork []TaskWorker + // haveSyscallReturn is true if tc.Arch().Return() represents a value // returned by a syscall (or set by ptrace after a syscall). // @@ -550,11 +565,20 @@ type Task struct { // futexWaiter is exclusive to the task goroutine. futexWaiter *futex.Waiter `state:"nosave"` + // robustList is a pointer to the head of the tasks's robust futex + // list. + robustList usermem.Addr + // startTime is the real time at which the task started. It is set when // a Task is created or invokes execve(2). // // startTime is protected by mu. startTime ktime.Time + + // kcov is the kcov instance providing code coverage owned by this task. + // + // kcov is exclusive to the task goroutine. + kcov *Kcov } func (t *Task) savePtraceTracer() *Task { @@ -711,17 +735,17 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock { func (t *Task) IsChrooted() bool { if VFS2Enabled { realRoot := t.mountNamespaceVFS2.Root() - defer realRoot.DecRef() + defer realRoot.DecRef(t) root := t.fsContext.RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) return root != realRoot } realRoot := t.tg.mounts.Root() - defer realRoot.DecRef() + defer realRoot.DecRef(t) root := t.fsContext.RootDirectory() if root != nil { - defer root.DecRef() + defer root.DecRef(t) } return root != realRoot } @@ -884,3 +908,16 @@ func (t *Task) UID() uint32 { func (t *Task) GID() uint32 { return uint32(t.Credentials().EffectiveKGID) } + +// SetKcov sets the kcov instance associated with t. +func (t *Task) SetKcov(k *Kcov) { + t.kcov = k +} + +// ResetKcov clears the kcov instance associated with t. +func (t *Task) ResetKcov() { + if t.kcov != nil { + t.kcov.OnTaskExit() + t.kcov = nil + } +} diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index e1ecca99e..fce1064a7 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -161,6 +161,10 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { return 0, nil, syserror.EINVAL } + // Pull task registers and FPU state, a cloned task will inherit the + // state of the current task. + t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) + // "If CLONE_NEWUSER is specified along with other CLONE_NEW* flags in a // single clone(2) or unshare(2) call, the user namespace is guaranteed to // be created first, giving the child (clone(2)) or caller (unshare(2)) @@ -237,7 +241,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { var fdTable *FDTable if opts.NewFiles { - fdTable = t.fdTable.Fork() + fdTable = t.fdTable.Fork(t) } else { fdTable = t.fdTable fdTable.IncRef() @@ -294,7 +298,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { nt, err := t.tg.pidns.owner.NewTask(cfg) if err != nil { if opts.NewThreadGroup { - tg.release() + tg.release(t) } return 0, nil, err } @@ -337,12 +341,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { nt.SetClearTID(opts.ChildTID) } if opts.ChildSetTID { - // Can't use Task.CopyOut, which assumes AddressSpaceActive. - usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{}) + ctid := nt.ThreadID() + ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { - t.CopyOut(opts.ParentTID, ntid) + ntid.CopyOut(t, opts.ParentTID) } kind := ptraceCloneKindClone @@ -510,7 +514,7 @@ func (t *Task) Unshare(opts *SharingOptions) error { var oldFDTable *FDTable if opts.NewFiles { oldFDTable = t.fdTable - t.fdTable = oldFDTable.Fork() + t.fdTable = oldFDTable.Fork(t) } var oldFSContext *FSContext if opts.NewFSContext { @@ -519,10 +523,10 @@ func (t *Task) Unshare(opts *SharingOptions) error { } t.mu.Unlock() if oldFDTable != nil { - oldFDTable.DecRef() + oldFDTable.DecRef(t) } if oldFSContext != nil { - oldFSContext.DecRef() + oldFSContext.DecRef(t) } return nil } diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 9fa528384..d1136461a 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -126,7 +126,11 @@ func (t *Task) SyscallTable() *SyscallTable { // Preconditions: The caller must be running on the task goroutine, or t.mu // must be locked. func (t *Task) Stack() *arch.Stack { - return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} + return &arch.Stack{ + Arch: t.Arch(), + IO: t.MemoryManager(), + Bottom: usermem.Addr(t.Arch().Stack()), + } } // LoadTaskImage loads a specified file into a new TaskContext. diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 9b69f3cbe..412d471d3 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -199,14 +199,17 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.Unlock() oldFDTable := t.fdTable - t.fdTable = t.fdTable.Fork() - oldFDTable.DecRef() + t.fdTable = t.fdTable.Fork(t) + oldFDTable.DecRef(t) // Remove FDs with the CloseOnExec flag set. - t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool { + t.fdTable.RemoveIf(t, func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool { return flags.CloseOnExec }) + // Handle the robust futex list. + t.exitRobustList() + // NOTE(b/30815691): We currently do not implement privileged // executables (set-user/group-ID bits and file capabilities). This // allows us to unconditionally enable user dumpability on the new mm. @@ -223,6 +226,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tc = *r.tc t.mu.Unlock() t.unstopVforkParent() + t.p.FullStateChanged() // NOTE(b/30316266): All locks must be dropped prior to calling Activate. t.MemoryManager().Activate(t) @@ -233,9 +237,10 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // promoteLocked makes t the leader of its thread group. If t is already the // thread group leader, promoteLocked is a no-op. // -// Preconditions: All other tasks in t's thread group, including the existing -// leader (if it is not t), have reached TaskExitZombie. The TaskSet mutex must -// be locked for writing. +// Preconditions: +// * All other tasks in t's thread group, including the existing leader (if it +// is not t), have reached TaskExitZombie. +// * The TaskSet mutex must be locked for writing. func (t *Task) promoteLocked() { oldLeader := t.tg.leader if t == oldLeader { diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c4ade6e8e..b400a8b41 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -239,6 +239,8 @@ func (*runExitMain) execute(t *Task) taskRunState { t.traceExitEvent() lastExiter := t.exitThreadGroup() + t.ResetKcov() + // If the task has a cleartid, and the thread group wasn't killed by a // signal, handle that before releasing the MM. if t.cleartid != 0 { @@ -246,13 +248,17 @@ func (*runExitMain) execute(t *Task) taskRunState { signaled := t.tg.exiting && t.tg.exitStatus.Signaled() t.tg.signalHandlers.mu.Unlock() if !signaled { - if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil { + zero := ThreadID(0) + if _, err := zero.CopyOut(t, t.cleartid); err == nil { t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1) } // If the CopyOut fails, there's nothing we can do. } } + // Handle the robust futex list. + t.exitRobustList() + // Deactivate the address space and update max RSS before releasing the // task's MM. t.Deactivate() @@ -266,12 +272,12 @@ func (*runExitMain) execute(t *Task) taskRunState { // Releasing the MM unblocks a blocked CLONE_VFORK parent. t.unstopVforkParent() - t.fsContext.DecRef() - t.fdTable.DecRef() + t.fsContext.DecRef(t) + t.fdTable.DecRef(t) t.mu.Lock() if t.mountNamespaceVFS2 != nil { - t.mountNamespaceVFS2.DecRef() + t.mountNamespaceVFS2.DecRef(t) t.mountNamespaceVFS2 = nil } t.mu.Unlock() @@ -279,7 +285,7 @@ func (*runExitMain) execute(t *Task) taskRunState { // If this is the last task to exit from the thread group, release the // thread group's resources. if lastExiter { - t.tg.release() + t.tg.release(t) } // Detach tracees. diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index a53e77c9f..c80391475 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -15,6 +15,8 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" ) @@ -52,3 +54,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) { func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) { return t.MemoryManager().GetSharedFutexKey(t, addr) } + +// GetRobustList sets the robust futex list for the task. +func (t *Task) GetRobustList() usermem.Addr { + t.mu.Lock() + addr := t.robustList + t.mu.Unlock() + return addr +} + +// SetRobustList sets the robust futex list for the task. +func (t *Task) SetRobustList(addr usermem.Addr) { + t.mu.Lock() + t.robustList = addr + t.mu.Unlock() +} + +// exitRobustList walks the robust futex list, marking locks dead and notifying +// wakers. It corresponds to Linux's exit_robust_list(). Following Linux, +// errors are silently ignored. +func (t *Task) exitRobustList() { + t.mu.Lock() + addr := t.robustList + t.robustList = 0 + t.mu.Unlock() + + if addr == 0 { + return + } + + var rl linux.RobustListHead + if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil { + return + } + + next := primitive.Uint64(rl.List) + done := 0 + var pendingLockAddr usermem.Addr + if rl.ListOpPending != 0 { + pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset) + } + + // Wake up normal elements. + for usermem.Addr(next) != addr { + // We traverse to the next element of the list before we + // actually wake anything. This prevents the race where waking + // this futex causes a modification of the list. + thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset) + + // Try to decode the next element in the list before waking the + // current futex. But don't check the error until after we've + // woken the current futex. Linux does it in this order too + _, nextErr := next.CopyIn(t, usermem.Addr(next)) + + // Wakeup the current futex if it's not pending. + if thisLockAddr != pendingLockAddr { + t.wakeRobustListOne(thisLockAddr) + } + + // If there was an error copying the next futex, we must bail. + if nextErr != nil { + break + } + + // This is a user structure, so it could be a massive list, or + // even contain a loop if they are trying to mess with us. We + // cap traversal to prevent that. + done++ + if done >= linux.ROBUST_LIST_LIMIT { + break + } + } + + // Is there a pending entry to wake? + if pendingLockAddr != 0 { + t.wakeRobustListOne(pendingLockAddr) + } +} + +// wakeRobustListOne wakes a single futex from the robust list. +func (t *Task) wakeRobustListOne(addr usermem.Addr) { + // Bit 0 in address signals PI futex. + pi := addr&1 == 1 + addr = addr &^ 1 + + // Load the futex. + f, err := t.LoadUint32(addr) + if err != nil { + // Can't read this single value? Ignore the problem. + // We can wake the other futexes in the list. + return + } + + tid := uint32(t.ThreadID()) + for { + // Is this held by someone else? + if f&linux.FUTEX_TID_MASK != tid { + return + } + + // This thread is dying and it's holding this futex. We need to + // set the owner died bit and wake up any waiters. + newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED + if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil { + return + } else if curF != f { + // Futex changed out from under us. Try again... + f = curF + continue + } + + // Wake waiters if there are any. + if f&linux.FUTEX_WAITERS != 0 { + private := f&linux.FUTEX_PRIVATE_FLAG != 0 + if pi { + t.Futex().UnlockPI(t, addr, tid, private) + return + } + t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1) + } + + // Done. + return + } +} diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index eeccaa197..d23cea802 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -27,6 +27,9 @@ const ( // maxStackDebugBytes is the maximum number of user stack bytes that may be // printed by debugDumpStack. maxStackDebugBytes = 1024 + // maxCodeDebugBytes is the maximum number of user code bytes that may be + // printed by debugDumpCode. + maxCodeDebugBytes = 128 ) // Infof logs an formatted info message by calling log.Infof. @@ -61,6 +64,7 @@ func (t *Task) IsLogging(level log.Level) bool { func (t *Task) DebugDumpState() { t.debugDumpRegisters() t.debugDumpStack() + t.debugDumpCode() if mm := t.MemoryManager(); mm != nil { t.Debugf("Mappings:\n%s", mm) } @@ -128,6 +132,45 @@ func (t *Task) debugDumpStack() { } } +// debugDumpCode logs user code contents at log level debug. +// +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) debugDumpCode() { + if !t.IsLogging(log.Debug) { + return + } + m := t.MemoryManager() + if m == nil { + t.Debugf("Memory manager for task is gone, skipping application code dump.") + return + } + t.Debugf("Code:") + // Print code on both sides of the instruction register. + start := usermem.Addr(t.Arch().IP()) - maxCodeDebugBytes/2 + // Round addr down to a 16-byte boundary. + start &= ^usermem.Addr(15) + // Print 16 bytes per line, one byte at a time. + for offset := uint64(0); offset < maxCodeDebugBytes; offset += 16 { + addr, ok := start.AddLength(offset) + if !ok { + break + } + var data [16]byte + n, err := m.CopyIn(t, addr, data[:], usermem.IOOpts{ + IgnorePermissions: true, + }) + // Print as much of the line as we can, even if an error was + // encountered. + if n > 0 { + t.Debugf("%x: % x", addr, data[:n]) + } + if err != nil { + t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err) + break + } + } +} + // trace definitions. // // Note that all region names are prefixed by ':' in order to ensure that they @@ -203,6 +246,6 @@ func (t *Task) traceExecEvent(tc *TaskContext) { trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") return } - defer file.DecRef() + defer file.DecRef(t) trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t)) } diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index d654dd997..8dc3fec90 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -26,6 +26,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -140,7 +141,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error { region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) - _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) + _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found) if err == nil && bytes.Equal(expected, found) { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) @@ -167,15 +168,30 @@ func (app *runApp) execute(t *Task) taskRunState { return (*runInterrupt)(nil) } - // We're about to switch to the application again. If there's still a + // Execute any task work callbacks before returning to user space. + if atomic.LoadInt32(&t.taskWorkCount) > 0 { + t.taskWorkMu.Lock() + queue := t.taskWork + t.taskWork = nil + atomic.StoreInt32(&t.taskWorkCount, 0) + t.taskWorkMu.Unlock() + + // Do not hold taskWorkMu while executing task work, which may register + // more work. + for _, work := range queue { + work.TaskWork(t) + } + } + + // We're about to switch to the application again. If there's still an // unhandled SyscallRestartErrno that wasn't translated to an EINTR, // restart the syscall that was interrupted. If there's a saved signal // mask, restore it. (Note that restoring the saved signal mask may unblock // a pending signal, causing another interruption, but that signal should // not interact with the interrupted syscall.) if t.haveSyscallReturn { - if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { - if sre == ERESTART_RESTARTBLOCK { + if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { + if sre == syserror.ERESTART_RESTARTBLOCK { t.Debugf("Restarting syscall %d with restart block after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre) t.Arch().RestartSyscallWithRestartBlock() } else { @@ -245,7 +261,7 @@ func (app *runApp) execute(t *Task) taskRunState { region := trace.StartRegion(t.traceContext, runRegion) t.accountTaskGoroutineEnter(TaskGoroutineRunningApp) - info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU) + info, at, err := t.p.Switch(t, t.MemoryManager(), t.Arch(), t.rseqCPU) t.accountTaskGoroutineLeave(TaskGoroutineRunningApp) region.End() diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 09366b60c..52c55d13d 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -133,9 +133,10 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) { } } -// Preconditions: The caller must be running on the task goroutine, and leaving -// a state indicated by a previous call to -// t.accountTaskGoroutineEnter(state). +// Preconditions: +// * The caller must be running on the task goroutine +// * The caller must be leaving a state indicated by a previous call to +// t.accountTaskGoroutineEnter(state). func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { if state != TaskGoroutineRunningApp { // Task is unblocking/continuing. @@ -191,8 +192,8 @@ func (tg *ThreadGroup) CPUStats() usage.CPUStats { return tg.cpuStatsAtLocked(tg.leader.k.CPUClockNow()) } -// Preconditions: As for TaskGoroutineSchedInfo.userTicksAt. The TaskSet mutex -// must be locked. +// Preconditions: Same as TaskGoroutineSchedInfo.userTicksAt, plus: +// * The TaskSet mutex must be locked. func (tg *ThreadGroup) cpuStatsAtLocked(now uint64) usage.CPUStats { stats := tg.exitedCPUStats // Account for live tasks. diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 79766cafe..ebdb83061 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -159,7 +159,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS sigact := computeAction(linux.Signal(info.Signo), act) if t.haveSyscallReturn { - if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { + if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { // Signals that are ignored, cause a thread group stop, or // terminate the thread group do not interact with interrupted // syscalls; in Linux terms, they are never returned to the signal @@ -168,11 +168,11 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS // signal that is actually handled (by userspace). if sigact == SignalActionHandler { switch { - case sre == ERESTARTNOHAND: + case sre == syserror.ERESTARTNOHAND: fallthrough - case sre == ERESTART_RESTARTBLOCK: + case sre == syserror.ERESTART_RESTARTBLOCK: fallthrough - case (sre == ERESTARTSYS && !act.IsRestart()): + case (sre == syserror.ERESTARTSYS && !act.IsRestart()): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) default: @@ -255,10 +255,15 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) } } + mm := t.MemoryManager() // Set up the signal handler. If we have a saved signal mask, the signal // handler should run with the current mask, but sigreturn should restore // the saved one. - st := &arch.Stack{t.Arch(), t.MemoryManager(), sp} + st := &arch.Stack{ + Arch: t.Arch(), + IO: mm, + Bottom: sp, + } mask := t.signalMask if t.haveSavedSignalMask { mask = t.savedSignalMask @@ -273,12 +278,13 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Please see the linux code as reference: // linux/arch/arm64/kernel/signal.c:setup_return() if act.Flags&linux.SA_RESTORER == 0 { - act.Restorer = t.MemoryManager().VDSOSigReturn() + act.Restorer = mm.VDSOSigReturn() } if err := t.Arch().SignalSetup(st, &act, info, &alt, mask); err != nil { return err } + t.p.FullStateChanged() t.haveSavedSignalMask = false // Add our signal mask. @@ -310,14 +316,16 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) { // Restore our signal mask. SIGKILL and SIGSTOP should not be blocked. t.SetSignalMask(sigset &^ UnblockableSignals) + t.p.FullStateChanged() return ctrlResume, nil } // Sigtimedwait implements the semantics of sigtimedwait(2). // -// Preconditions: The caller must be running on the task goroutine. t.exitState -// < TaskExitZombie. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t.exitState < TaskExitZombie. func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) { // set is the set of signals we're interested in; invert it to get the set // of signals to block. @@ -581,8 +589,9 @@ func (t *Task) SignalMask() linux.SignalSet { // SetSignalMask sets t's signal mask. // -// Preconditions: SetSignalMask can only be called by the task goroutine. -// t.exitState < TaskExitZombie. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t.exitState < TaskExitZombie. func (t *Task) SetSignalMask(mask linux.SignalSet) { // By precondition, t prevents t.tg from completing an execve and mutating // t.tg.signalHandlers, so we can skip the TaskSet mutex. @@ -628,7 +637,7 @@ func (t *Task) setSignalMaskLocked(mask linux.SignalSet) { // SetSavedSignalMask sets the saved signal mask (see Task.savedSignalMask's // comment). // -// Preconditions: SetSavedSignalMask can only be called by the task goroutine. +// Preconditions: The caller must be running on the task goroutine. func (t *Task) SetSavedSignalMask(mask linux.SignalSet) { t.savedSignalMask = mask t.haveSavedSignalMask = true @@ -636,6 +645,7 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) { // SignalStack returns the task-private signal stack. func (t *Task) SignalStack() arch.SignalStack { + t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) alt := t.signalStack if t.onSignalStack(alt) { alt.Flags |= arch.SignalStackFlagOnStack @@ -1050,6 +1060,8 @@ func (*runInterrupt) execute(t *Task) taskRunState { // Are there signals pending? if info := t.dequeueSignalLocked(t.signalMask); info != nil { + t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) + if linux.SignalSetOf(linux.Signal(info.Signo))&StopSignals != 0 { // Indicate that we've dequeued a stop signal before unlocking the // signal mutex; initiateGroupStop will check for races with diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 8485fb4b6..64c1e120a 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -102,10 +102,10 @@ func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) { t, err := ts.newTask(cfg) if err != nil { cfg.TaskContext.release() - cfg.FSContext.DecRef() - cfg.FDTable.DecRef() + cfg.FSContext.DecRef(t) + cfg.FDTable.DecRef(t) if cfg.MountNamespaceVFS2 != nil { - cfg.MountNamespaceVFS2.DecRef() + cfg.MountNamespaceVFS2.DecRef(t) } return nil, err } diff --git a/pkg/sentry/kernel/task_stop.go b/pkg/sentry/kernel/task_stop.go index 10c6e455c..a35948a5f 100644 --- a/pkg/sentry/kernel/task_stop.go +++ b/pkg/sentry/kernel/task_stop.go @@ -99,8 +99,9 @@ type TaskStop interface { // beginInternalStop indicates the start of an internal stop that applies to t. // -// Preconditions: The task must not already be in an internal stop (i.e. t.stop -// == nil). The caller must be running on the task goroutine. +// Preconditions: +// * The caller must be running on the task goroutine. +// * The task must not already be in an internal stop (i.e. t.stop == nil). func (t *Task) beginInternalStop(s TaskStop) { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() @@ -109,8 +110,8 @@ func (t *Task) beginInternalStop(s TaskStop) { t.beginInternalStopLocked(s) } -// Preconditions: The signal mutex must be locked. All preconditions for -// Task.beginInternalStop also apply. +// Preconditions: Same as beginInternalStop, plus: +// * The signal mutex must be locked. func (t *Task) beginInternalStopLocked(s TaskStop) { if t.stop != nil { panic(fmt.Sprintf("Attempting to enter internal stop %#v when already in internal stop %#v", s, t.stop)) @@ -128,8 +129,9 @@ func (t *Task) beginInternalStopLocked(s TaskStop) { // t.stop, which is why there is no endInternalStop that locks the signal mutex // for you. // -// Preconditions: The signal mutex must be locked. The task must be in an -// internal stop (i.e. t.stop != nil). +// Preconditions: +// * The signal mutex must be locked. +// * The task must be in an internal stop (i.e. t.stop != nil). func (t *Task) endInternalStopLocked() { if t.stop == nil { panic("Attempting to leave non-existent internal stop") @@ -205,6 +207,22 @@ func (ts *TaskSet) BeginExternalStop() { } } +// PullFullState receives full states for all tasks. +func (ts *TaskSet) PullFullState() { + ts.mu.Lock() + defer ts.mu.Unlock() + if ts.Root == nil { + return + } + for t := range ts.Root.tids { + t.Activate() + if mm := t.MemoryManager(); mm != nil { + t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) + } + t.Deactivate() + } +} + // EndExternalStop indicates the end of an external stop started by a previous // call to TaskSet.BeginExternalStop. EndExternalStop does not wait for task // goroutines to resume. diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index a5903b0b5..0141459e7 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -29,75 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel -// include/linux/errno.h. These errnos are never returned to userspace -// directly, but are used to communicate the expected behavior of an -// interrupted syscall from the syscall to signal handling. -type SyscallRestartErrno int - -// These numeric values are significant because ptrace syscall exit tracing can -// observe them. -// -// For all of the following errnos, if the syscall is not interrupted by a -// signal delivered to a user handler, the syscall is restarted. -const ( - // ERESTARTSYS is returned by an interrupted syscall to indicate that it - // should be converted to EINTR if interrupted by a signal delivered to a - // user handler without SA_RESTART set, and restarted otherwise. - ERESTARTSYS = SyscallRestartErrno(512) - - // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it - // should always be restarted. - ERESTARTNOINTR = SyscallRestartErrno(513) - - // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it - // should be converted to EINTR if interrupted by a signal delivered to a - // user handler, and restarted otherwise. - ERESTARTNOHAND = SyscallRestartErrno(514) - - // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate - // that it should be restarted using a custom function. The interrupted - // syscall must register a custom restart function by calling - // Task.SetRestartSyscallFn. - ERESTART_RESTARTBLOCK = SyscallRestartErrno(516) -) - var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application") -// Error implements error.Error. -func (e SyscallRestartErrno) Error() string { - // Descriptions are borrowed from strace. - switch e { - case ERESTARTSYS: - return "to be restarted if SA_RESTART is set" - case ERESTARTNOINTR: - return "to be restarted" - case ERESTARTNOHAND: - return "to be restarted if no handler" - case ERESTART_RESTARTBLOCK: - return "interrupted by signal" - default: - return "(unknown interrupt error)" - } -} - -// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by -// rv, the value in a syscall return register. -func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) { - switch int(rv) { - case -int(ERESTARTSYS): - return ERESTARTSYS, true - case -int(ERESTARTNOINTR): - return ERESTARTNOINTR, true - case -int(ERESTARTNOHAND): - return ERESTARTNOHAND, true - case -int(ERESTART_RESTARTBLOCK): - return ERESTART_RESTARTBLOCK, true - default: - return 0, false - } -} - // SyscallRestartBlock represents the restart block for a syscall restartable // with a custom function. It encapsulates the state required to restart a // syscall across a S/R. @@ -354,7 +288,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { // Grab the caller up front, to make sure there's a sensible stack. caller := t.Arch().Native(uintptr(0)) - if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil { + if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil { t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) @@ -390,7 +324,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { type runVsyscallAfterPtraceEventSeccomp struct { addr usermem.Addr sysno uintptr - caller interface{} + caller marshal.Marshallable } func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { @@ -413,7 +347,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller) } -func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState { +func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller marshal.Marshallable) taskRunState { rval, ctrl, err := t.executeSyscall(sysno, args) if ctrl != nil { t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl) @@ -447,7 +381,7 @@ func ExtractErrno(err error, sysno int) int { return 0 case syscall.Errno: return int(err) - case SyscallRestartErrno: + case syserror.SyscallRestartErrno: return int(err) case *memmap.BusError: // Bus errors may generate SIGBUS, but for syscalls they still diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index b02044ad2..ce134bf54 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -43,17 +44,6 @@ func (t *Task) Deactivate() { } } -// CopyIn copies a fixed-size value or slice of fixed-size values in from the -// task's memory. The copy will fail with syscall.EFAULT if it traverses user -// memory that is unmapped or not readable by the user. -// -// This Task's AddressSpace must be active. -func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) { - return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{ - AddressSpaceActive: true, - }) -} - // CopyInBytes is a fast version of CopyIn if the caller can serialize the // data without reflection and pass in a byte slice. // @@ -64,17 +54,6 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { }) } -// CopyOut copies a fixed-size value or slice of fixed-size values out to the -// task's memory. The copy will fail with syscall.EFAULT if it traverses user -// memory that is unmapped or not writeable by the user. -// -// This Task's AddressSpace must be active. -func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) { - return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{ - AddressSpaceActive: true, - }) -} - // CopyOutBytes is a fast version of CopyOut if the caller can serialize the // data without reflection and pass in a byte slice. // @@ -114,7 +93,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ var v []string for { argAddr := t.Arch().Native(0) - if _, err := t.CopyIn(addr, argAddr); err != nil { + if _, err := argAddr.CopyIn(t, addr); err != nil { return v, err } if t.Arch().Value(argAddr) == 0 { @@ -143,8 +122,9 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ // CopyOutIovecs converts src to an array of struct iovecs and copies it to the // memory mapped at addr. // -// Preconditions: As for usermem.IO.CopyOut. The caller must be running on the -// task goroutine. t's AddressSpace must be active. +// Preconditions: Same as usermem.IO.CopyOut, plus: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error { switch t.Arch().Width() { case 8: @@ -191,8 +171,9 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error // combined length of all AddrRanges would otherwise exceed this amount, ranges // beyond MAX_RW_COUNT are silently truncated. // -// Preconditions: As for usermem.IO.CopyIn. The caller must be running on the -// task goroutine. t's AddressSpace must be active. +// Preconditions: Same as usermem.IO.CopyIn, plus: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRangeSeq, error) { if numIovecs == 0 { return usermem.AddrRangeSeq{}, nil @@ -284,7 +265,7 @@ func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOp // // IovecsIOSequence is analogous to Linux's lib/iov_iter.c:import_iovec(). // -// Preconditions: As for Task.CopyInIovecs. +// Preconditions: Same as Task.CopyInIovecs. func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) { if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV { return usermem.IOSequence{}, syserror.EINVAL @@ -299,3 +280,30 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp Opts: opts, }, nil } + +// copyContext implements marshal.CopyContext. It wraps a task to allow copying +// memory to and from the task memory with custom usermem.IOOpts. +type copyContext struct { + *Task + opts usermem.IOOpts +} + +// AsCopyContext wraps the task and returns it as CopyContext. +func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { + return ©Context{t, opts} +} + +// CopyInString copies a string in from the task's memory. +func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { + return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) +} + +// CopyInBytes copies task memory into dst from an IO context. +func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + return t.MemoryManager().CopyIn(t, addr, dst, t.opts) +} + +// CopyOutBytes copies src into task memoryfrom an IO context. +func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + return t.MemoryManager().CopyOut(t, addr, src, t.opts) +} diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go new file mode 100644 index 000000000..dda5a433a --- /dev/null +++ b/pkg/sentry/kernel/task_work.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync/atomic" + +// TaskWorker is a deferred task. +// +// This must be savable. +type TaskWorker interface { + // TaskWork will be executed prior to returning to user space. Note that + // TaskWork may call RegisterWork again, but this will not be executed until + // the next return to user space, unlike in Linux. This effectively allows + // registration of indefinite user return hooks, but not by default. + TaskWork(t *Task) +} + +// RegisterWork can be used to register additional task work that will be +// performed prior to returning to user space. See TaskWorker.TaskWork for +// semantics regarding registration. +func (t *Task) RegisterWork(work TaskWorker) { + t.taskWorkMu.Lock() + defer t.taskWorkMu.Unlock() + atomic.AddInt32(&t.taskWorkCount, 1) + t.taskWork = append(t.taskWork, work) +} diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 4dfd2c990..0b34c0099 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -308,7 +308,7 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet { } // release releases the thread group's resources. -func (tg *ThreadGroup) release() { +func (tg *ThreadGroup) release(t *Task) { // Timers must be destroyed without holding the TaskSet or signal mutexes // since timers send signals with Timer.mu locked. tg.itimerRealTimer.Destroy() @@ -325,7 +325,7 @@ func (tg *ThreadGroup) release() { it.DestroyTimer() } if tg.mounts != nil { - tg.mounts.DecRef() + tg.mounts.DecRef(t) } } diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 872e1a82d..5ae5906e8 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -36,6 +36,8 @@ import ( const TasksLimit = (1 << 16) // ThreadID is a generic thread identifier. +// +// +marshal type ThreadID int32 // String returns a decimal representation of the ThreadID. diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 7ba7dc50c..2817aa3ba 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -6,6 +6,7 @@ go_library( name = "time", srcs = [ "context.go", + "tcpip.go", "time.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go new file mode 100644 index 000000000..c4474c0cf --- /dev/null +++ b/pkg/sentry/kernel/time/tcpip.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package time + +import ( + "sync" + "time" +) + +// TcpipAfterFunc waits for duration to elapse according to clock then runs fn. +// The timer is started immediately and will fire exactly once. +func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer { + timer := &TcpipTimer{ + clock: clock, + } + timer.notifier = functionNotifier{ + fn: func() { + // tcpip.Timer.Stop() explicitly states that the function is called in a + // separate goroutine that Stop() does not synchronize with. + // Timer.Destroy() synchronizes with calls to TimerListener.Notify(). + // This is semantically meaningful because, in the former case, it's + // legal to call tcpip.Timer.Stop() while holding locks that may also be + // taken by the function, but this isn't so in the latter case. Most + // immediately, Timer calls TimerListener.Notify() while holding + // Timer.mu. A deadlock occurs without spawning a goroutine: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock! + // + // Spawning a goroutine avoids the deadlock: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() <- Launches T2 + // T2: + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, blocks + // T1: + // => (returns) <- Timer.mu.Unlock() called + // T2: + // => (continues) <- No deadlock! + go func() { + timer.Stop() + fn() + }() + }, + } + timer.Reset(duration) + return timer +} + +// TcpipTimer is a resettable timer with variable duration expirations. +// Implements tcpip.Timer, which does not define a Destroy method; instead, all +// resources are released after timer expiration and calls to Timer.Stop. +// +// Must be created by AfterFunc. +type TcpipTimer struct { + // clock is the time source. clock is immutable. + clock Clock + + // notifier is called when the Timer expires. notifier is immutable. + notifier functionNotifier + + // mu protects t. + mu sync.Mutex + + // t stores the latest running Timer. This is replaced whenever Reset is + // called since Timer cannot be restarted once it has been Destroyed by Stop. + // + // This field is nil iff Stop has been called. + t *Timer +} + +// Stop implements tcpip.Timer.Stop. +func (r *TcpipTimer) Stop() bool { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + return false + } + _, lastSetting := r.t.Swap(Setting{}) + r.t.Destroy() + r.t = nil + return lastSetting.Enabled +} + +// Reset implements tcpip.Timer.Reset. +func (r *TcpipTimer) Reset(d time.Duration) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + r.t = NewTimer(r.clock, &r.notifier) + } + + r.t.Swap(Setting{ + Enabled: true, + Period: 0, + Next: r.clock.Now().Add(d), + }) +} + +// functionNotifier is a TimerListener that runs a function. +// +// functionNotifier cannot be saved or loaded. +type functionNotifier struct { + fn func() +} + +// Notify implements ktime.TimerListener.Notify. +func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) { + f.fn() + return Setting{}, false +} + +// Destroy implements ktime.TimerListener.Destroy. +func (f *functionNotifier) Destroy() {} diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index e959700f2..f61a8e164 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -616,8 +616,10 @@ func (t *Timer) Swap(s Setting) (Time, Setting) { // Timer's Clock) at which the Setting was changed. Setting s.Enabled to true // starts the timer, while setting s.Enabled to false stops it. // -// Preconditions: The Timer must not be paused. f cannot call any Timer methods -// since it is called with the Timer mutex locked. +// Preconditions: +// * The Timer must not be paused. +// * f cannot call any Timer methods since it is called with the Timer mutex +// locked. func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { now := t.clock.Now() t.mu.Lock() diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 0adf25691..7c4fefb16 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/log" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -90,7 +90,7 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. -func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) { +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { return &Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), }, nil @@ -210,9 +210,6 @@ func (t *Timekeeper) startUpdater() { p.realtimeBaseRef = int64(realtimeParams.BaseRef) p.realtimeFrequency = realtimeParams.Frequency } - - log.Debugf("Updating VDSO parameters: %+v", p) - return p }); err != nil { log.Warningf("Unable to update VDSO parameter page: %v", err) diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index f1b3c212c..9bc452e67 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -17,10 +17,9 @@ package kernel import ( "fmt" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -28,6 +27,8 @@ import ( // // They are exposed to the VDSO via a parameter page managed by VDSOParamPage, // which also includes a sequence counter. +// +// +marshal type vdsoParams struct { monotonicReady uint64 monotonicBaseCycles int64 @@ -58,7 +59,7 @@ type vdsoParams struct { type VDSOParamPage struct { // The parameter page is fr, allocated from mfp.MemoryFile(). mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange // seq is the current sequence count written to the page. // @@ -68,21 +69,29 @@ type VDSOParamPage struct { // checked in state_test_util tests, causing this field to change across // save / restore. seq uint64 + + // copyScratchBuffer is a temporary buffer used to marshal the params before + // copying it to the real parameter page. The parameter page is typically + // updated at a moderate frequency of ~O(seconds) throughout the lifetime of + // the sentry, so reusing this buffer is a good tradeoff between memory + // usage and the cost of allocation. + copyScratchBuffer []byte } // NewVDSOParamPage returns a VDSOParamPage. // // Preconditions: -// // * fr is a single page allocated from mfp.MemoryFile(). VDSOParamPage does // not take ownership of fr; it must remain allocated for the lifetime of the // VDSOParamPage. -// // * VDSOParamPage must be the only writer to fr. -// // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. -func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage { - return &VDSOParamPage{mfp: mfp, fr: fr} +func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { + return &VDSOParamPage{ + mfp: mfp, + fr: fr, + copyScratchBuffer: make([]byte, (*vdsoParams)(nil).SizeBytes()), + } } // access returns a mapping of the param page. @@ -136,7 +145,8 @@ func (v *VDSOParamPage) Write(f func() vdsoParams) error { // Get the new params. p := f() - buf := binary.Marshal(nil, usermem.ByteOrder, p) + buf := v.copyScratchBuffer[:p.SizeBytes()] + p.MarshalUnsafe(buf) // Skip the sequence counter. if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil { diff --git a/pkg/sentry/limits/context.go b/pkg/sentry/limits/context.go index 77e1fe217..0bade6e57 100644 --- a/pkg/sentry/limits/context.go +++ b/pkg/sentry/limits/context.go @@ -33,3 +33,12 @@ func FromContext(ctx context.Context) *LimitSet { } return nil } + +// FromContextOrDie returns FromContext(ctx) if the latter is not nil. +// Otherwise, panic is triggered. +func FromContextOrDie(ctx context.Context) *LimitSet { + if v := ctx.Value(CtxLimits); v != nil { + return v.(*LimitSet) + } + panic("failed to create limit set from context") +} diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index c6aa65f28..34bdb0b69 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -30,9 +30,6 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", "//pkg/sentry/fsbridge", "//pkg/sentry/kernel/auth", "//pkg/sentry/limits", @@ -45,6 +42,5 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/usermem", - "//pkg/waiter", ], ) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 616fafa2c..d4610ec3b 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -90,14 +90,23 @@ type elfInfo struct { sharedObject bool } +// fullReader interface extracts the ReadFull method from fsbridge.File so that +// client code does not need to define an entire fsbridge.File when only read +// functionality is needed. +// +// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap +// vfs.FileDescription's PRead/Read instead. +type fullReader interface { + // ReadFull is the same as fsbridge.File.ReadFull. + ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) +} + // parseHeader parse the ELF header, verifying that this is a supported ELF // file and returning the ELF program headers. // // This is similar to elf.NewFile, except that it is more strict about what it // accepts from the ELF, and it doesn't parse unnecessary parts of the file. -// -// ctx may be nil if f does not need it. -func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) { +func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // Check ident first; it will tell us the endianness of the rest of the // structs. var ident [elf.EI_NIDENT]byte @@ -272,7 +281,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr } defer func() { if mopts.MappingIdentity != nil { - mopts.MappingIdentity.DecRef() + mopts.MappingIdentity.DecRef(ctx) } }() if err := f.ConfigureMMap(ctx, &mopts); err != nil { @@ -393,8 +402,7 @@ type loadedELF struct { // // It does not load the ELF interpreter, or return any auxv entries. // -// Preconditions: -// * f is an ELF file +// Preconditions: f is an ELF file. func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) { first := true var start, end usermem.Addr @@ -562,8 +570,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in // It does not load the ELF interpreter, or return any auxv entries. // // Preconditions: -// * f is an ELF file -// * f is the first ELF loaded into m +// * f is an ELF file. +// * f is the first ELF loaded into m. func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f fsbridge.File) (loadedELF, arch.Context, error) { info, err := parseHeader(ctx, f) if err != nil { @@ -600,8 +608,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS // // It does not return any auxv entries. // -// Preconditions: -// * f is an ELF file +// Preconditions: f is an ELF file. func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) { info, err := parseHeader(ctx, f) if err != nil { @@ -631,8 +638,7 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.Fil // If loadELF returns ErrSwitchFile it should be called again with the returned // path and argv. // -// Preconditions: -// * args.File is an ELF file +// Preconditions: args.File is an ELF file. func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { @@ -654,7 +660,7 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err } - defer intFile.DecRef() + defer intFile.DecRef(ctx) interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin) if err != nil { diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 88449fe95..c69b62db9 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -80,22 +79,6 @@ type LoadArgs struct { Features *cpuid.FeatureSet } -// readFull behaves like io.ReadFull for an *fs.File. -func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var total int64 - for dst.NumBytes() > 0 { - n, err := f.Preadv(ctx, dst, offset+total) - total += n - if err == io.EOF && total != 0 { - return total, io.ErrUnexpectedEOF - } else if err != nil { - return total, err - } - dst = dst.DropFirst64(n) - } - return total, nil -} - // openPath opens args.Filename and checks that it is valid for loading. // // openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not @@ -139,7 +122,7 @@ func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch if err != nil { return nil, err } - return &arch.Stack{a, m, ar.End}, nil + return &arch.Stack{Arch: a, IO: m, Bottom: ar.End}, nil } const ( @@ -171,7 +154,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context return loadedELF{}, nil, nil, nil, err } // Ensure file is release in case the code loops or errors out. - defer args.File.DecRef() + defer args.File.DecRef(ctx) } else { if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil { return loadedELF{}, nil, nil, nil, err @@ -232,20 +215,20 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context // path and argv. // // Preconditions: -// * The Task MemoryManager is empty. -// * Load is called on the Task goroutine. +// * The Task MemoryManager is empty. +// * Load is called on the Task goroutine. func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { // Load the executable itself. loaded, ac, file, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } - defer file.DecRef() + defer file.DecRef(ctx) // Load the VDSO. vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the @@ -264,20 +247,20 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(args.Filename) - if err != nil { + if _, err := stack.PushNullTerminatedByteSlice([]byte(args.Filename)); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } + execfn := stack.Bottom // Push 16 random bytes on the stack which AT_RANDOM will point to. var b [16]byte if _, err := rand.Read(b[:]); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux()) } - random, err := stack.Push(b) - if err != nil { + if _, err = stack.PushNullTerminatedByteSlice(b[:]); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux()) } + random := stack.Bottom c := auth.CredentialsFromContext(ctx) @@ -309,7 +292,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V m.SetEnvvStart(sl.EnvvStart) m.SetEnvvEnd(sl.EnvvEnd) m.SetAuxv(auxv) - m.SetExecutable(file) + m.SetExecutable(ctx, file) symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn") if err != nil { diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index 165869028..05a294fe6 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -26,10 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -37,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) const vdsoPrelink = 0xffffffffff700000 @@ -55,52 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} { } } -// byteReader implements fs.FileOperations for reading from a []byte source. -type byteReader struct { - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` - +type byteFullReader struct { data []byte } -var _ fs.FileOperations = (*byteReader)(nil) - -// newByteReaderFile creates a fake file to read data from. -// -// TODO(gvisor.dev/issue/2921): Convert to VFS2. -func newByteReaderFile(ctx context.Context, data []byte) *fs.File { - // Create a fake inode. - inode := fs.NewInode( - ctx, - &fsutil.SimpleFileInode{}, - fs.NewPseudoMountSource(ctx), - fs.StableAttr{ - Type: fs.Anonymous, - DeviceID: anon.PseudoDevice.DeviceID(), - InodeID: anon.PseudoDevice.NextIno(), - BlockSize: usermem.PageSize, - }) - - // Use the fake inode to create a fake dirent. - dirent := fs.NewTransientDirent(inode) - defer dirent.DecRef() - - // Use the fake dirent to make a fake file. - flags := fs.FileFlags{Read: true, Pread: true} - return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{ - data: data, - }) -} - -func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { +func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } @@ -111,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ return int64(n), err } -func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { - panic("Write not supported") -} - // validateVDSO checks that the VDSO can be loaded by loadVDSO. // // VDSOs are special (see below). Since we are going to map the VDSO directly @@ -130,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq // * PT_LOAD segments don't extend beyond the end of the file. // // ctx may be nil if f does not need it. -func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) { +func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) { info, err := parseHeader(ctx, f) if err != nil { log.Infof("Unable to parse VDSO header: %v", err) @@ -248,13 +198,12 @@ func getSymbolValueFromVDSO(symbol string) (uint64, error) { // PrepareVDSO validates the system VDSO and returns a VDSO, containing the // param page for updating by the kernel. -func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) { - vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin)) +func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { + vdsoFile := &byteFullReader{data: vdsoBin} // First make sure the VDSO is valid. vdsoFile does not use ctx, so a // nil context can be passed. info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin))) - vdsoFile.DecRef() if err != nil { return nil, err } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index a98b66de1..2c95669cd 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -28,9 +28,21 @@ go_template_instance( }, ) +go_template_instance( + name = "file_range", + out = "file_range.go", + package = "memmap", + prefix = "File", + template = "//pkg/segment:generic_range", + types = { + "T": "uint64", + }, +) + go_library( name = "memmap", srcs = [ + "file_range.go", "mappable_range.go", "mapping_set.go", "mapping_set_impl.go", @@ -40,7 +52,7 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/sentry/platform", + "//pkg/safemem", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go index d609c1ae0..457ed87f8 100644 --- a/pkg/sentry/memmap/mapping_set.go +++ b/pkg/sentry/memmap/mapping_set.go @@ -177,7 +177,7 @@ func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr // AddMapping adds the given mapping and returns the set of MappableRanges that // previously had no mappings. // -// Preconditions: As for Mappable.AddMapping. +// Preconditions: Same as Mappable.AddMapping. func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange { mr := MappableRange{offset, offset + uint64(ar.Length())} var mapped []MappableRange @@ -204,7 +204,7 @@ func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset ui // RemoveMapping removes the given mapping and returns the set of // MappableRanges that now have no mappings. // -// Preconditions: As for Mappable.RemoveMapping. +// Preconditions: Same as Mappable.RemoveMapping. func (s *MappingSet) RemoveMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange { mr := MappableRange{offset, offset + uint64(ar.Length())} var unmapped []MappableRange diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index c6db9fc8f..a44fa2b95 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -19,18 +19,18 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 -// offsets to (platform.File, uint64 File offset) pairs. +// offsets to (File, uint64 File offset) pairs. // // See mm/mm.go for Mappable's place in the lock order. // -// Preconditions: For all Mappable methods, usermem.AddrRanges and -// MappableRanges must be non-empty (Length() != 0), and usermem.Addrs and -// Mappable offsets must be page-aligned. +// All Mappable methods have the following preconditions: +// * usermem.AddrRanges and MappableRanges must be non-empty (Length() != 0). +// * usermem.Addrs and Mappable offsets must be page-aligned. type Mappable interface { // AddMapping notifies the Mappable of a mapping from addresses ar in ms to // offsets [offset, offset+ar.Length()) in this Mappable. @@ -48,8 +48,10 @@ type Mappable interface { // addresses ar in ms to offsets [offset, offset+ar.Length()) in this // Mappable. // - // Preconditions: offset+ar.Length() does not overflow. The removed mapping - // must exist. writable must match the corresponding call to AddMapping. + // Preconditions: + // * offset+ar.Length() does not overflow. + // * The removed mapping must exist. writable must match the + // corresponding call to AddMapping. RemoveMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) // CopyMapping notifies the Mappable of an attempt to copy a mapping in ms @@ -60,9 +62,10 @@ type Mappable interface { // CopyMapping is only called when a mapping is copied within a given // MappingSpace; it is analogous to Linux's vm_operations_struct::mremap. // - // Preconditions: offset+srcAR.Length() and offset+dstAR.Length() do not - // overflow. The mapping at srcAR must exist. writable must match the - // corresponding call to AddMapping. + // Preconditions: + // * offset+srcAR.Length() and offset+dstAR.Length() do not overflow. + // * The mapping at srcAR must exist. writable must match the + // corresponding call to AddMapping. CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error // Translate returns the Mappable's current mappings for at least the range @@ -74,14 +77,17 @@ type Mappable interface { // Translations are valid until invalidated by a callback to // MappingSpace.Invalidate or until the caller removes its mapping of the // translated range. Mappable implementations must ensure that at least one - // reference is held on all pages in a platform.File that may be the result + // reference is held on all pages in a File that may be the result // of a valid Translation. // - // Preconditions: required.Length() > 0. optional.IsSupersetOf(required). - // required and optional must be page-aligned. The caller must have - // established a mapping for all of the queried offsets via a previous call - // to AddMapping. The caller is responsible for ensuring that calls to - // Translate synchronize with invalidation. + // Preconditions: + // * required.Length() > 0. + // * optional.IsSupersetOf(required). + // * required and optional must be page-aligned. + // * The caller must have established a mapping for all of the queried + // offsets via a previous call to AddMapping. + // * The caller is responsible for ensuring that calls to Translate + // synchronize with invalidation. // // Postconditions: See CheckTranslateResult. Translate(ctx context.Context, required, optional MappableRange, at usermem.AccessType) ([]Translation, error) @@ -100,7 +106,7 @@ type Translation struct { Source MappableRange // File is the mapped file. - File platform.File + File File // Offset is the offset into File at which this Translation begins. Offset uint64 @@ -110,15 +116,15 @@ type Translation struct { Perms usermem.AccessType } -// FileRange returns the platform.FileRange represented by t. -func (t Translation) FileRange() platform.FileRange { - return platform.FileRange{t.Offset, t.Offset + t.Source.Length()} +// FileRange returns the FileRange represented by t. +func (t Translation) FileRange() FileRange { + return FileRange{t.Offset, t.Offset + t.Source.Length()} } // CheckTranslateResult returns an error if (ts, terr) does not satisfy all // postconditions for Mappable.Translate(required, optional, at). // -// Preconditions: As for Mappable.Translate. +// Preconditions: Same as Mappable.Translate. func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error { // Verify that the inputs to Mappable.Translate were valid. if !required.WellFormed() || required.Length() <= 0 { @@ -214,7 +220,9 @@ type MappingSpace interface { // Invalidate must not take any locks preceding mm.MemoryManager.activeMu // in the lock order. // - // Preconditions: ar.Length() != 0. ar must be page-aligned. + // Preconditions: + // * ar.Length() != 0. + // * ar must be page-aligned. Invalidate(ar usermem.AddrRange, opts InvalidateOpts) } @@ -238,7 +246,7 @@ type MappingIdentity interface { IncRef() // DecRef decrements the MappingIdentity's reference count. - DecRef() + DecRef(ctx context.Context) // MappedName returns the application-visible name shown in // /proc/[pid]/maps. @@ -360,4 +368,62 @@ type MMapOpts struct { // // TODO(jamieliu): Replace entirely with MappingIdentity? Hint string + + // Force means to skip validation checks of Addr and Length. It can be + // used to create special mappings below mm.layout.MinAddr and + // mm.layout.MaxAddr. It has to be used with caution. + // + // If Force is true, Unmap and Fixed must be true. + Force bool +} + +// File represents a host file that may be mapped into an platform.AddressSpace. +type File interface { + // All pages in a File are reference-counted. + + // IncRef increments the reference count on all pages in fr. + // + // Preconditions: + // * fr.Start and fr.End must be page-aligned. + // * fr.Length() > 0. + // * At least one reference must be held on all pages in fr. (The File + // interface does not provide a way to acquire an initial reference; + // implementors may define mechanisms for doing so.) + IncRef(fr FileRange) + + // DecRef decrements the reference count on all pages in fr. + // + // Preconditions: + // * fr.Start and fr.End must be page-aligned. + // * fr.Length() > 0. + // * At least one reference must be held on all pages in fr. + DecRef(fr FileRange) + + // MapInternal returns a mapping of the given file offsets in the invoking + // process' address space for reading and writing. + // + // Note that fr.Start and fr.End need not be page-aligned. + // + // Preconditions: + // * fr.Length() > 0. + // * At least one reference must be held on all pages in fr. + // + // Postconditions: The returned mapping is valid as long as at least one + // reference is held on the mapped pages. + MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) + + // FD returns the file descriptor represented by the File. + // + // The only permitted operation on the returned file descriptor is to map + // pages from it consistent with the requirements of AddressSpace.MapFile. + FD() int +} + +// FileRange represents a range of uint64 offsets into a File. +// +// type FileRange <generated using go_generics> + +// String implements fmt.Stringer.String. +func (fr FileRange) String() string { + return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) } diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index a036ce53c..b4a47ccca 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -7,14 +7,14 @@ go_template_instance( name = "file_refcount_set", out = "file_refcount_set.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "mm", prefix = "fileRefcount", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "int32", "Functions": "fileRefcountSetFunctions", }, @@ -73,12 +73,35 @@ go_template_instance( }, ) +go_template_instance( + name = "aio_mappable_refs", + out = "aio_mappable_refs.go", + package = "mm", + prefix = "aioMappable", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "aioMappable", + }, +) + +go_template_instance( + name = "special_mappable_refs", + out = "special_mappable_refs.go", + package = "mm", + prefix = "SpecialMappable", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "SpecialMappable", + }, +) + go_library( name = "mm", srcs = [ "address_space.go", "aio_context.go", "aio_context_state.go", + "aio_mappable_refs.go", "debug.go", "file_refcount_set.go", "io.go", @@ -92,6 +115,7 @@ go_library( "save_restore.go", "shm.go", "special_mappable.go", + "special_mappable_refs.go", "syscalls.go", "vma.go", "vma_set.go", diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index 5c667117c..a93e76c75 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -166,8 +166,12 @@ func (mm *MemoryManager) Deactivate() { // mapASLocked maps addresses in ar into mm.as. If precommit is true, mappings // for all addresses in ar should be precommitted. // -// Preconditions: mm.activeMu must be locked. mm.as != nil. ar.Length() != 0. -// ar must be page-aligned. pseg == mm.pmas.LowerBoundSegment(ar.Start). +// Preconditions: +// * mm.activeMu must be locked. +// * mm.as != nil. +// * ar.Length() != 0. +// * ar must be page-aligned. +// * pseg == mm.pmas.LowerBoundSegment(ar.Start). func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, precommit bool) error { // By default, map entire pmas at a time, under the assumption that there // is no cost to mapping more of a pma than necessary. diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 379148903..7bf48cb2c 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -17,10 +17,8 @@ package mm import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -240,10 +238,10 @@ func (ctx *AIOContext) Drain() { // // +stateify savable type aioMappable struct { - refs.AtomicRefCount + aioMappableRefs mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange } var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp()) @@ -254,13 +252,13 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) { return nil, err } m := aioMappable{mfp: mfp, fr: fr} - m.EnableLeakCheck("mm.aioMappable") + m.EnableLeakCheck() return &m, nil } // DecRef implements refs.RefCounter.DecRef. -func (m *aioMappable) DecRef() { - m.AtomicRefCount.DecRefWithDestructor(func() { +func (m *aioMappable) DecRef(ctx context.Context) { + m.aioMappableRefs.DecRef(func() { m.mfp.MemoryFile().DecRef(m.fr) }) } @@ -368,7 +366,7 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint if err != nil { return 0, err } - defer m.DecRef() + defer m.DecRef(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ Length: aioRingBufferSize, MappingIdentity: m, diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go index fa776f9c6..a8ac48080 100644 --- a/pkg/sentry/mm/io.go +++ b/pkg/sentry/mm/io.go @@ -441,7 +441,10 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts // handleASIOFault handles a page fault at address addr for an AddressSpaceIO // operation spanning ioar. // -// Preconditions: mm.as != nil. ioar.Length() != 0. ioar.Contains(addr). +// Preconditions: +// * mm.as != nil. +// * ioar.Length() != 0. +// * ioar.Contains(addr). func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, ioar usermem.AddrRange, at usermem.AccessType) error { // Try to map all remaining pages in the I/O operation. This RoundUp can't // overflow because otherwise it would have been caught by CheckIORange. @@ -629,7 +632,9 @@ func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars userme // at most address end on AddrRange arsit.Head(). It is used in vector I/O paths to // truncate usermem.AddrRangeSeq when errors occur. // -// Preconditions: !arsit.IsEmpty(). end <= arsit.Head().End. +// Preconditions: +// * !arsit.IsEmpty(). +// * end <= arsit.Head().End. func truncatedAddrRangeSeq(ars, arsit usermem.AddrRangeSeq, end usermem.Addr) usermem.AddrRangeSeq { ar := arsit.Head() if end <= ar.Start { diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index aac56679b..09dbc06a4 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -57,6 +57,8 @@ func (mm *MemoryManager) SetMmapLayout(ac arch.Context, r *limits.LimitSet) (arc // Fork creates a copy of mm with 1 user, as for Linux syscalls fork() or // clone() (without CLONE_VM). func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { + mm.AddressSpace().PreFork() + defer mm.AddressSpace().PostFork() mm.metadataMu.Lock() defer mm.metadataMu.Unlock() mm.mappingMu.RLock() @@ -258,7 +260,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { mm.executable = nil mm.metadataMu.Unlock() if exe != nil { - exe.DecRef() + exe.DecRef(ctx) } mm.activeMu.Lock() diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index 28e5057f7..0cfd60f6c 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -15,6 +15,7 @@ package mm import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/usermem" @@ -147,7 +148,7 @@ func (mm *MemoryManager) Executable() fsbridge.File { // SetExecutable sets the executable. // // This takes a reference on d. -func (mm *MemoryManager) SetExecutable(file fsbridge.File) { +func (mm *MemoryManager) SetExecutable(ctx context.Context, file fsbridge.File) { mm.metadataMu.Lock() // Grab a new reference. @@ -164,7 +165,7 @@ func (mm *MemoryManager) SetExecutable(file fsbridge.File) { // Do this without holding the lock, since it may wind up doing some // I/O to sync the dirent, etc. if orig != nil { - orig.DecRef() + orig.DecRef(ctx) } } diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 6db7c3d40..8c9f11cce 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -25,7 +25,7 @@ // Locks taken by memmap.Mappable.Translate // mm.privateRefs.mu // platform.AddressSpace locks -// platform.File locks +// memmap.File locks // mm.aioManager.mu // mm.AIOContext.mu // @@ -242,7 +242,7 @@ type MemoryManager struct { // +stateify savable type vma struct { // mappable is the virtual memory object mapped by this vma. If mappable is - // nil, the vma represents a private anonymous mapping. + // nil, the vma represents an anonymous mapping. mappable memmap.Mappable // off is the offset into mappable at which this vma begins. If mappable is @@ -396,7 +396,7 @@ type pma struct { // file is the file mapped by this pma. Only pmas for which file == // MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to // the corresponding file range while they exist. - file platform.File `state:"nosave"` + file memmap.File `state:"nosave"` // off is the offset into file at which this pma begins. // @@ -436,7 +436,7 @@ type pma struct { private bool // If internalMappings is not empty, it is the cached return value of - // file.MapInternal for the platform.FileRange mapped by this pma. + // file.MapInternal for the memmap.FileRange mapped by this pma. internalMappings safemem.BlockSeq `state:"nosave"` } @@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 { func (fileRefcountSetFunctions) ClearValue(_ *int32) { } -func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) { +func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) { return rc1, rc1 == rc2 } -func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) { +func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) { return rc, rc } diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index fdc308542..acac3d357 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -51,7 +51,8 @@ func TestUsageASUpdates(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * usermem.PageSize, + Length: 2 * usermem.PageSize, + Private: true, }) if err != nil { t.Fatalf("MMap got err %v want nil", err) diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 62e4c20af..30facebf7 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -32,7 +31,9 @@ import ( // iterator to the pma containing ar.Start. Otherwise it returns a terminal // iterator. // -// Preconditions: mm.activeMu must be locked. ar.Length() != 0. +// Preconditions: +// * mm.activeMu must be locked. +// * ar.Length() != 0. func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -90,10 +91,13 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user // // - An error that is non-nil if pmas exist for only a subset of ar. // -// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for -// writing. ar.Length() != 0. vseg.Range().Contains(ar.Start). vmas must exist -// for all addresses in ar, and support accesses of type at (i.e. permission -// checks must have been performed against vmas). +// Preconditions: +// * mm.mappingMu must be locked. +// * mm.activeMu must be locked for writing. +// * ar.Length() != 0. +// * vseg.Range().Contains(ar.Start). +// * vmas must exist for all addresses in ar, and support accesses of type at +// (i.e. permission checks must have been performed against vmas). func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -136,9 +140,11 @@ func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar // exist. If this is not equal to ars, it returns a non-nil error explaining // why. // -// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for -// writing. vmas must exist for all addresses in ars, and support accesses of -// type at (i.e. permission checks must have been performed against vmas). +// Preconditions: +// * mm.mappingMu must be locked. +// * mm.activeMu must be locked for writing. +// * vmas must exist for all addresses in ars, and support accesses of type at +// (i.e. permission checks must have been performed against vmas). func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType) (usermem.AddrRangeSeq, error) { for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() { ar := arsit.Head() @@ -519,8 +525,10 @@ func privateAligned(ar usermem.AddrRange) usermem.AddrRange { // the memory it maps, isPMACopyOnWriteLocked will take ownership of the memory // and update the pma to indicate that it does not require copy-on-write. // -// Preconditions: vseg.Range().IsSupersetOf(pseg.Range()). mm.mappingMu must be -// locked. mm.activeMu must be locked for writing. +// Preconditions: +// * vseg.Range().IsSupersetOf(pseg.Range()). +// * mm.mappingMu must be locked. +// * mm.activeMu must be locked for writing. func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterator) bool { pma := pseg.ValuePtr() if !pma.needCOW { @@ -569,8 +577,10 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate // invalidateLocked removes pmas and AddressSpace mappings of those pmas for // addresses in ar. // -// Preconditions: mm.activeMu must be locked for writing. ar.Length() != 0. ar -// must be page-aligned. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -604,7 +614,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat } } -// Pin returns the platform.File ranges currently mapped by addresses in ar in +// Pin returns the memmap.File ranges currently mapped by addresses in ar in // mm, acquiring a reference on the returned ranges which the caller must // release by calling Unpin. If not all addresses are mapped, Pin returns a // non-nil error. Note that Pin may return both a non-empty slice of @@ -614,7 +624,9 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat // most I/O. It should only be used in contexts that would use get_user_pages() // in the Linux kernel. // -// Preconditions: ar.Length() != 0. ar must be page-aligned. +// Preconditions: +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -674,15 +686,15 @@ type PinnedRange struct { Source usermem.AddrRange // File is the mapped file. - File platform.File + File memmap.File // Offset is the offset into File at which this PinnedRange begins. Offset uint64 } -// FileRange returns the platform.File offsets mapped by pr. -func (pr PinnedRange) FileRange() platform.FileRange { - return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} +// FileRange returns the memmap.File offsets mapped by pr. +func (pr PinnedRange) FileRange() memmap.FileRange { + return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} } // Unpin releases the reference held by prs. @@ -694,9 +706,13 @@ func Unpin(prs []PinnedRange) { // movePMAsLocked moves all pmas in oldAR to newAR. // -// Preconditions: mm.activeMu must be locked for writing. oldAR.Length() != 0. -// oldAR.Length() <= newAR.Length(). !oldAR.Overlaps(newAR). -// mm.pmas.IsEmptyRange(newAR). oldAR and newAR must be page-aligned. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * oldAR.Length() != 0. +// * oldAR.Length() <= newAR.Length(). +// * !oldAR.Overlaps(newAR). +// * mm.pmas.IsEmptyRange(newAR). +// * oldAR and newAR must be page-aligned. func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { if checkInvariants { if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() { @@ -752,9 +768,11 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { // - An error that is non-nil if internal mappings exist for only a subset of // ar. // -// Preconditions: mm.activeMu must be locked for writing. -// pseg.Range().Contains(ar.Start). pmas must exist for all addresses in ar. -// ar.Length() != 0. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * pseg.Range().Contains(ar.Start). +// * pmas must exist for all addresses in ar. +// * ar.Length() != 0. // // Postconditions: getPMAInternalMappingsLocked does not invalidate iterators // into mm.pmas. @@ -784,8 +802,9 @@ func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar userm // internal mappings exist. If this is not equal to ars, it returns a non-nil // error explaining why. // -// Preconditions: mm.activeMu must be locked for writing. pmas must exist for -// all addresses in ar. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * pmas must exist for all addresses in ar. // // Postconditions: getVecPMAInternalMappingsLocked does not invalidate iterators // into mm.pmas. @@ -804,9 +823,12 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe // internalMappingsLocked returns internal mappings for addresses in ar. // -// Preconditions: mm.activeMu must be locked. Internal mappings must have been -// previously established for all addresses in ar. ar.Length() != 0. -// pseg.Range().Contains(ar.Start). +// Preconditions: +// * mm.activeMu must be locked. +// * Internal mappings must have been previously established for all addresses +// in ar. +// * ar.Length() != 0. +// * pseg.Range().Contains(ar.Start). func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -840,8 +862,10 @@ func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.Add // vecInternalMappingsLocked returns internal mappings for addresses in ars. // -// Preconditions: mm.activeMu must be locked. Internal mappings must have been -// previously established for all addresses in ars. +// Preconditions: +// * mm.activeMu must be locked. +// * Internal mappings must have been previously established for all addresses +// in ars. func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) safemem.BlockSeq { var ims []safemem.Block for ; !ars.IsEmpty(); ars = ars.Tail() { @@ -857,7 +881,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf } // incPrivateRef acquires a reference on private pages in fr. -func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { +func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) { mm.privateRefs.mu.Lock() defer mm.privateRefs.mu.Unlock() refSet := &mm.privateRefs.refs @@ -878,8 +902,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { } // decPrivateRef releases a reference on private pages in fr. -func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) { - var freed []platform.FileRange +func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) { + var freed []memmap.FileRange mm.privateRefs.mu.Lock() refSet := &mm.privateRefs.refs @@ -951,7 +975,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa // Discard internal mappings instead of trying to merge them, since merging // them requires an allocation and getting them again from the - // platform.File might not. + // memmap.File might not. pma1.internalMappings = safemem.BlockSeq{} return pma1, true } @@ -970,7 +994,9 @@ func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (p // findOrSeekPrevUpperBoundPMA returns mm.pmas.UpperBoundSegment(addr), but may do // so by scanning linearly backward from pgap. // -// Preconditions: mm.activeMu must be locked. addr <= pgap.Start(). +// Preconditions: +// * mm.activeMu must be locked. +// * addr <= pgap.Start(). func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr usermem.Addr, pgap pmaGapIterator) pmaIterator { if checkInvariants { if !pgap.Ok() { @@ -1012,12 +1038,14 @@ func (pseg pmaIterator) getInternalMappingsLocked() error { return nil } -func (pseg pmaIterator) fileRange() platform.FileRange { +func (pseg pmaIterator) fileRange() memmap.FileRange { return pseg.fileRangeOf(pseg.Range()) } -// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0. -func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { +// Preconditions: +// * pseg.Range().IsSupersetOf(ar). +// * ar.Length != 0. +func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { panic("terminal pma iterator") @@ -1032,5 +1060,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { pma := pseg.ValuePtr() pstart := pseg.Start() - return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} + return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 9ad52082d..2dbe5b751 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -16,10 +16,8 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -32,10 +30,10 @@ import ( // // +stateify savable type SpecialMappable struct { - refs.AtomicRefCount + SpecialMappableRefs mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange name string } @@ -44,15 +42,15 @@ type SpecialMappable struct { // SpecialMappable will use the given name in /proc/[pid]/maps. // // Preconditions: fr.Length() != 0. -func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable { +func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} - m.EnableLeakCheck("mm.SpecialMappable") + m.EnableLeakCheck() return &m } // DecRef implements refs.RefCounter.DecRef. -func (m *SpecialMappable) DecRef() { - m.AtomicRefCount.DecRefWithDestructor(func() { +func (m *SpecialMappable) DecRef(ctx context.Context) { + m.SpecialMappableRefs.DecRef(func() { m.mfp.MemoryFile().DecRef(m.fr) }) } @@ -126,7 +124,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider { // FileRange returns the offsets into MemoryFileProvider().MemoryFile() that // store the SpecialMappable's contents. -func (m *SpecialMappable) FileRange() platform.FileRange { +func (m *SpecialMappable) FileRange() memmap.FileRange { return m.fr } @@ -138,9 +136,12 @@ func (m *SpecialMappable) Length() uint64 { // NewSharedAnonMappable returns a SpecialMappable that implements the // semantics of mmap(MAP_SHARED|MAP_ANONYMOUS) and mappings of /dev/zero. // -// TODO(jamieliu): The use of SpecialMappable is a lazy code reuse hack. Linux -// uses an ephemeral file created by mm/shmem.c:shmem_zero_setup(); we should -// do the same to get non-zero device and inode IDs. +// TODO(gvisor.dev/issue/1624): Linux uses an ephemeral file created by +// mm/shmem.c:shmem_zero_setup(), and VFS2 does something analogous. VFS1 uses +// a SpecialMappable instead, incorrectly getting device and inode IDs of zero +// and causing memory for shared anonymous mappings to be allocated up-front +// instead of on first touch; this is to avoid exacerbating the fs.MountSource +// leak (b/143656263). Delete this function along with VFS1. func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) { if length == 0 { return nil, syserror.EINVAL diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 3f496aa9f..a2555ba1a 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -93,18 +92,6 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme } } else { opts.Offset = 0 - if !opts.Private { - if opts.MappingIdentity != nil { - return 0, syserror.EINVAL - } - m, err := NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) - if err != nil { - return 0, err - } - defer m.DecRef() - opts.MappingIdentity = m - opts.Mappable = m - } } if opts.Addr.RoundDown() != opts.Addr { @@ -166,7 +153,9 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme // populateVMA obtains pmas for addresses in ar in the given vma, and maps them // into mm.as if it is active. // -// Preconditions: mm.mappingMu must be locked. vseg.Range().IsSupersetOf(ar). +// Preconditions: +// * mm.mappingMu must be locked. +// * vseg.Range().IsSupersetOf(ar). func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) { if !vseg.ValuePtr().effectivePerms.Any() { // Linux doesn't populate inaccessible pages. See @@ -208,8 +197,9 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar u // preferable to populateVMA since it unlocks mm.mappingMu before performing // expensive operations that don't require it to be locked. // -// Preconditions: mm.mappingMu must be locked for writing. -// vseg.Range().IsSupersetOf(ar). +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * vseg.Range().IsSupersetOf(ar). // // Postconditions: mm.mappingMu will be unlocked. func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) { @@ -1191,7 +1181,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length ui mr := vseg.mappableRangeOf(vseg.Range().Intersect(ar)) mm.mappingMu.RUnlock() err := id.Msync(ctx, mr) - id.DecRef() + id.DecRef(ctx) if err != nil { return err } diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 16d8207e9..f769d8294 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -27,8 +27,9 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Preconditions: mm.mappingMu must be locked for writing. opts must be valid -// as defined by the checks in MMap. +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * opts must be valid as defined by the checks in MMap. func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, usermem.AddrRange, error) { if opts.MaxPerms != opts.MaxPerms.Effective() { panic(fmt.Sprintf("Non-effective MaxPerms %s cannot be enforced", opts.MaxPerms)) @@ -42,7 +43,12 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp Map32Bit: opts.Map32Bit, }) if err != nil { - return vmaIterator{}, usermem.AddrRange{}, err + // Can't force without opts.Unmap and opts.Fixed. + if opts.Force && opts.Unmap && opts.Fixed { + addr = opts.Addr + } else { + return vmaIterator{}, usermem.AddrRange{}, err + } } ar, _ := addr.ToRange(opts.Length) @@ -255,8 +261,9 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 { // // - An error that is non-nil if vmas exist for only a subset of ar. // -// Preconditions: mm.mappingMu must be locked for reading; it may be -// temporarily unlocked. ar.Length() != 0. +// Preconditions: +// * mm.mappingMu must be locked for reading; it may be temporarily unlocked. +// * ar.Length() != 0. func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -337,8 +344,10 @@ const guardBytes = 256 * usermem.PageSize // unmapLocked unmaps all addresses in ar and returns the resulting gap in // mm.vmas. // -// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0. -// ar must be page-aligned. +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -356,8 +365,10 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) // gap in mm.vmas. It does not remove pmas or AddressSpace mappings; clients // must do so before calling removeVMAsLocked. // -// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0. ar -// must be page-aligned. +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -377,7 +388,7 @@ func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRa vma.mappable.RemoveMapping(ctx, mm, vmaAR, vma.off, vma.canWriteMappableLocked()) } if vma.id != nil { - vma.id.DecRef() + vma.id.DecRef(ctx) } mm.usageAS -= uint64(vmaAR.Length()) if vma.isPrivateDataLocked() { @@ -446,7 +457,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa } if vma2.id != nil { - vma2.id.DecRef() + vma2.id.DecRef(context.Background()) } return vma1, true } @@ -462,7 +473,9 @@ func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (v return v, v2 } -// Preconditions: vseg.ValuePtr().mappable != nil. vseg.Range().Contains(addr). +// Preconditions: +// * vseg.ValuePtr().mappable != nil. +// * vseg.Range().Contains(addr). func (vseg vmaIterator) mappableOffsetAt(addr usermem.Addr) uint64 { if checkInvariants { if !vseg.Ok() { @@ -486,8 +499,10 @@ func (vseg vmaIterator) mappableRange() memmap.MappableRange { return vseg.mappableRangeOf(vseg.Range()) } -// Preconditions: vseg.ValuePtr().mappable != nil. -// vseg.Range().IsSupersetOf(ar). ar.Length() != 0. +// Preconditions: +// * vseg.ValuePtr().mappable != nil. +// * vseg.Range().IsSupersetOf(ar). +// * ar.Length() != 0. func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRange { if checkInvariants { if !vseg.Ok() { @@ -509,8 +524,10 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan return memmap.MappableRange{vma.off + uint64(ar.Start-vstart), vma.off + uint64(ar.End-vstart)} } -// Preconditions: vseg.ValuePtr().mappable != nil. -// vseg.mappableRange().IsSupersetOf(mr). mr.Length() != 0. +// Preconditions: +// * vseg.ValuePtr().mappable != nil. +// * vseg.mappableRange().IsSupersetOf(mr). +// * mr.Length() != 0. func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { if checkInvariants { if !vseg.Ok() { @@ -535,7 +552,9 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { // seekNextLowerBound returns mm.vmas.LowerBoundSegment(addr), but does so by // scanning linearly forward from vseg. // -// Preconditions: mm.mappingMu must be locked. addr >= vseg.Start(). +// Preconditions: +// * mm.mappingMu must be locked. +// * addr >= vseg.Start(). func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator { if checkInvariants { if !vseg.Ok() { diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index e1fcb175f..7a3311a70 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -36,14 +36,14 @@ go_template_instance( "trackGaps": "1", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "usage", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "usageInfo", "Functions": "usageSetFunctions", }, @@ -56,14 +56,14 @@ go_template_instance( "minDegree": "10", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "reclaim", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "reclaimSetValue", "Functions": "reclaimSetFunctions", }, @@ -89,7 +89,7 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/hostmm", - "//pkg/sentry/platform", + "//pkg/sentry/memmap", "//pkg/sentry/usage", "//pkg/state", "//pkg/state/wire", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index afab97c0a..626d1eaa4 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -33,14 +33,14 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) -// MemoryFile is a platform.File whose pages may be allocated to arbitrary +// MemoryFile is a memmap.File whose pages may be allocated to arbitrary // users. type MemoryFile struct { // opts holds options passed to NewMemoryFile. opts is immutable. @@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() { // to Allocate. // // Preconditions: length must be page-aligned and non-zero. -func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) { +func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) { if length == 0 || length%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid allocation length: %#x", length)) } @@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Find a range in the underlying file. fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) if !ok { - return platform.FileRange{}, syserror.ENOMEM + return memmap.FileRange{}, syserror.ENOMEM } // Expand the file if needed. @@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Round the new file size up to be chunk-aligned. newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask if err := f.file.Truncate(newFileSize); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } f.fileSize = newFileSize f.mappingsMu.Lock() @@ -409,16 +409,16 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi f.mappingsMu.Unlock() } - // Mark selected pages as in use. if f.opts.ManualZeroing { if err := f.forEachMappingSlice(fr, func(bs []byte) { for i := range bs { bs[i] = 0 } }); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } } + // Mark selected pages as in use. if !f.usage.Add(fr, usageInfo{ kind: kind, refs: 1, @@ -439,7 +439,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // space for mappings to be allocated downwards. // // Precondition: alignment must be a power of 2. -func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) { alignmentMask := alignment - 1 // Search for space in existing gaps, starting at the current end of the @@ -461,7 +461,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 break } if start := unalignedStart &^ alignmentMask; start >= gap.Start() { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } gap = gap.PrevLargeEnoughGap(length) @@ -475,7 +475,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 min = (min + alignmentMask) &^ alignmentMask if min+length < min { // Overflow: allocation would exceed the range of uint64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } // Determine the minimum file size required to fit this allocation at its end. @@ -484,7 +484,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 if newFileSize <= fileSize { if fileSize != 0 { // Overflow: allocation would exceed the range of int64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } newFileSize = chunkSize } @@ -496,7 +496,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 continue } if start := unalignedStart &^ alignmentMask; start >= min { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } } } @@ -507,23 +507,25 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 // nearest page. If this is shorter than length bytes due to an error returned // by r.ReadToBlocks(), it returns that error. // -// Preconditions: length > 0. length must be page-aligned. -func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) { +// Preconditions: +// * length > 0. +// * length must be page-aligned. +func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) { fr, err := f.Allocate(length, kind) if err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } dsts, err := f.MapInternal(fr, usermem.Write) if err != nil { f.DecRef(fr) - return platform.FileRange{}, err + return memmap.FileRange{}, err } n, err := safemem.ReadFullToBlocks(r, dsts) un := uint64(usermem.Addr(n).RoundDown()) if un < length { // Free unused memory and update fr to contain only the memory that is // still allocated. - f.DecRef(platform.FileRange{fr.Start + un, fr.End}) + f.DecRef(memmap.FileRange{fr.Start + un, fr.End}) fr.End = fr.Start + un } return fr, err @@ -540,7 +542,7 @@ const ( // will read zeroes. // // Preconditions: fr.Length() > 0. -func (f *MemoryFile) Decommit(fr platform.FileRange) error { +func (f *MemoryFile) Decommit(fr memmap.FileRange) error { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -560,7 +562,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error { return nil } -func (f *MemoryFile) markDecommitted(fr platform.FileRange) { +func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() // Since we're changing the knownCommitted attribute, we need to merge @@ -581,8 +583,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) { f.usage.MergeRange(fr) } -// IncRef implements platform.File.IncRef. -func (f *MemoryFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (f *MemoryFile) IncRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -600,8 +602,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) } -// DecRef implements platform.File.DecRef. -func (f *MemoryFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (f *MemoryFile) DecRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -637,8 +639,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } } -// MapInternal implements platform.File.MapInternal. -func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { if !fr.WellFormed() || fr.Length() == 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -664,7 +666,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) ( // forEachMappingSlice invokes fn on a sequence of byte slices that // collectively map all bytes in fr. -func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error { +func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error { mappings := f.mappings.Load().([]uintptr) for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { chunk := int(chunkStart >> chunkShift) @@ -944,7 +946,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( continue case !populated && populatedRun: // Finish the run by changing this segment. - runRange := platform.FileRange{ + runRange := memmap.FileRange{ Start: r.Start + uint64(populatedRunStart*usermem.PageSize), End: r.Start + uint64(i*usermem.PageSize), } @@ -1009,7 +1011,7 @@ func (f *MemoryFile) File() *os.File { return f.file } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (f *MemoryFile) FD() int { return int(f.file.Fd()) } @@ -1090,13 +1092,13 @@ func (f *MemoryFile) runReclaim() { // // Note that there returned range will be removed from tracking. It // must be reclaimed (removed from f.usage) at this point. -func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { +func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() for { for { if f.destroyed { - return platform.FileRange{}, false + return memmap.FileRange{}, false } if f.reclaimable { break @@ -1120,7 +1122,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } } -func (f *MemoryFile) markReclaimed(fr platform.FileRange) { +func (f *MemoryFile) markReclaimed(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) @@ -1167,8 +1169,10 @@ func (f *MemoryFile) startEvictionsLocked() bool { return startedAny } -// Preconditions: info == f.evictable[user]. !info.evicting. f.mu must be -// locked. +// Preconditions: +// * info == f.evictable[user]. +// * !info.evicting. +// * f.mu must be locked. func (f *MemoryFile) startEvictionGoroutineLocked(user EvictableMemoryUser, info *evictableMemoryUserInfo) { info.evicting = true f.evictionWG.Add(1) @@ -1222,11 +1226,11 @@ func (usageSetFunctions) MaxKey() uint64 { func (usageSetFunctions) ClearValue(val *usageInfo) { } -func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) { +func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) { return val1, val1 == val2 } -func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { +func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { return val, val } @@ -1270,10 +1274,10 @@ func (reclaimSetFunctions) MaxKey() uint64 { func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { } -func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { +func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { return reclaimSetValue{}, true } -func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { +func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { return reclaimSetValue{}, reclaimSetValue{} } diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 453241eca..209b28053 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,39 +1,21 @@ load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "platform", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - go_library( name = "platform", srcs = [ "context.go", - "file_range.go", "mmap_min_addr.go", "platform.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/atomicbitops", "//pkg/context", - "//pkg/log", - "//pkg/safecopy", - "//pkg/safemem", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/usage", - "//pkg/syserror", + "//pkg/sentry/memmap", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go index 57be41647..9dfac3eae 100644 --- a/pkg/sentry/platform/interrupt/interrupt.go +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -54,8 +54,9 @@ type Forwarder struct { // } // defer f.Disable() // -// Preconditions: r must not be nil. f must not already be forwarding -// interrupts to a Receiver. +// Preconditions: +// * r must not be nil. +// * f must not already be forwarding interrupts to a Receiver. func (f *Forwarder) Enable(r Receiver) bool { if r == nil { panic("nil Receiver") diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 4792454c4..323837fb1 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -9,12 +9,12 @@ go_library( "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", - "bluepill_amd64.s", "bluepill_amd64_unsafe.go", "bluepill_arm64.go", "bluepill_arm64.s", "bluepill_arm64_unsafe.go", "bluepill_fault.go", + "bluepill_impl_amd64.s", "bluepill_unsafe.go", "context.go", "filters_amd64.go", @@ -41,12 +41,14 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/atomicbitops", + "//pkg/context", "//pkg/cpuid", "//pkg/log", "//pkg/procid", "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/ring0", @@ -60,6 +62,7 @@ go_library( go_test( name = "kvm_test", srcs = [ + "kvm_amd64_test.go", "kvm_test.go", "virtual_map_test.go", ], @@ -78,3 +81,11 @@ go_test( "//pkg/usermem", ], ) + +genrule( + name = "bluepill_impl_amd64", + srcs = ["bluepill_amd64.s"], + outs = ["bluepill_impl_amd64.s"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index faf1d5e1c..af5c5e191 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem. } // MapFile implements platform.AddressSpace.MapFile. -func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { as.mu.Lock() defer as.mu.Unlock() @@ -247,3 +248,9 @@ func (as *addressSpace) Release() { // Drop all cached machine references. as.machine.dropPageTables(as.pageTables) } + +// PreFork implements platform.AddressSpace.PreFork. +func (as *addressSpace) PreFork() {} + +// PostFork implements platform.AddressSpace.PostFork. +func (as *addressSpace) PostFork() {} diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index 2bc34a435..025ea93b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -19,11 +19,6 @@ // This is guaranteed to be zero. #define VCPU_CPU 0x0 -// CPU_SELF is the self reference in ring0's percpu. -// -// This is guaranteed to be zero. -#define CPU_SELF 0x0 - // Context offsets. // // Only limited use of the context is done in the assembly stub below, most is @@ -44,7 +39,7 @@ begin: LEAQ VCPU_CPU(AX), BX BYTE CLI; check_vcpu: - MOVQ CPU_SELF(GS), CX + MOVQ ENTRY_CPU_SELF(GS), CX CMPQ BX, CX JE right_vCPU wrong_vcpu: diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 03a98512e..0a54dd30d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -83,5 +83,34 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillReadyStopGuest(c *vCPU) bool { - return c.runData.readyForInterruptInjection != 0 + if c.runData.readyForInterruptInjection == 0 { + return false + } + + if c.runData.ifFlag == 0 { + // This is impossible if readyForInterruptInjection is 1. + throw("interrupts are disabled") + } + + // Disable interrupts if we are in the kernel space. + // + // When the Sentry switches into the kernel mode, it disables + // interrupts. But when goruntime switches on a goroutine which has + // been saved in the host mode, it restores flags and this enables + // interrupts. See the comment of UserFlagsSet for more details. + uregs := userRegs{} + err := c.getUserRegisters(&uregs) + if err != 0 { + throw("failed to get user registers") + } + + if ring0.IsKernelFlags(uregs.RFLAGS) { + uregs.RFLAGS &^= ring0.KernelFlagsClear + err = c.setUserRegisters(&uregs) + if err != 0 { + throw("failed to set user registers") + } + return false + } + return true } diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index dba563160..ed5ae03d3 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -49,7 +49,7 @@ func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { regs.Sp = context.Sp regs.Pc = context.Pc regs.Pstate = context.Pstate - regs.Pstate &^= uint64(ring0.KernelFlagsClear) + regs.Pstate &^= uint64(ring0.PsrFlagsClear) regs.Pstate |= ring0.KernelFlagsSet return } @@ -63,7 +63,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { context.Sp = regs.Sp context.Pc = regs.Pc context.Pstate = regs.Pstate - context.Pstate &^= uint64(ring0.UserFlagsClear) + context.Pstate &^= uint64(ring0.PsrFlagsClear) context.Pstate |= ring0.UserFlagsSet lazyVfp := c.GetLazyVFP() diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 8b64f3a1e..b35c930e2 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -41,7 +41,7 @@ func fpsimdPtr(addr *byte) *arch.FpsimdContext { func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { // If the vCPU is in user mode, we set the stack to the stored stack // value in the vCPU itself. We don't want to unwind the user stack. - if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t { + if guestRegs.Regs.Pstate&ring0.PsrModeMask == ring0.UserFlagsSet { regs := c.CPU.Registers() context.Regs[0] = regs.Regs[0] context.Sp = regs.Sp diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index e34f46aeb..a182e4f22 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -98,6 +98,10 @@ func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegi } errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { + // Store the physical address in the slot. This is used to + // avoid calls to handleBluepillFault in the future (see + // machine.mapPhysical). + atomic.StoreUintptr(&m.usedSlots[slot], physical) // Successfully added region; we can increment nextSlot and // allow another set to proceed here. atomic.StoreUint32(&m.nextSlot, slot+1) diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index bf357de1a..979be5d89 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index 6507121ea..6e6b76416 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -15,6 +15,7 @@ package kvm import ( + pkgcontext "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" @@ -37,7 +38,8 @@ type context struct { } // Switch runs the provided context in the given address space. -func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) { + as := mm.AddressSpace() localAS := as.(*addressSpace) // Grab a vCPU. @@ -88,3 +90,9 @@ func (c *context) Interrupt() { // Release implements platform.Context.Release(). func (c *context) Release() {} + +// FullStateChanged implements platform.Context.FullStateChanged. +func (c *context) FullStateChanged() {} + +// PullFullState implements platform.Context.PullFullState. +func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {} diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ae813e24e..d46946402 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -156,15 +156,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. pageTables := pagetables.New(newAllocator()) - applyPhysicalRegions(func(pr physicalRegion) bool { - // Map the kernel in the upper half. - pageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. - }) + k.machine.mapUpperHalf(pageTables) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go new file mode 100644 index 000000000..c0b4fd374 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +func TestSegments(t *testing.T) { + applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { + testutil.SetTestSegments(regs) + for { + var si arch.SignalInfo + if _, err := c.SwitchToUser(ring0.SwitchOpts{ + Registers: regs, + FloatingPointState: dummyFPState, + PageTables: pt, + FullRestore: true, + }, &si); err == platform.ErrContextInterrupt { + continue // Retry. + } else if err != nil { + t.Errorf("application segment check with full restore got unexpected error: %v", err) + } + if err := testutil.CheckTestSegments(regs); err != nil { + t.Errorf("application segment check with full restore failed: %v", err) + } + break // Done. + } + return false + }) +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go index 6531bae1d..48ccf8474 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go @@ -22,7 +22,8 @@ import ( ) var ( - runDataSize int + runDataSize int + hasGuestPCID bool ) func updateSystemValues(fd int) error { @@ -33,6 +34,7 @@ func updateSystemValues(fd int) error { } // Save the data. runDataSize = int(sz) + hasGuestPCID = true // Success. return nil diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 3bf918446..5f627a016 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -32,6 +32,7 @@ const ( _KVM_SET_REGS = 0x4090ae82 _KVM_SET_SREGS = 0x4138ae84 _KVM_GET_REGS = 0x8090ae81 + _KVM_GET_SREGS = 0x8138ae83 _KVM_GET_SUPPORTED_CPUID = 0xc008ae05 _KVM_SET_CPUID2 = 0x4008ae90 _KVM_SET_SIGNAL_MASK = 0x4004ae8b @@ -56,6 +57,7 @@ const ( // KVM capability options. const ( + _KVM_CAP_MAX_MEMSLOTS = 0x0a _KVM_CAP_MAX_VCPUS = 0x42 _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 _KVM_CAP_VCPU_EVENTS = 0x29 @@ -64,6 +66,7 @@ const ( // KVM limits. const ( + _KVM_NR_MEMSLOTS = 0x100 _KVM_NR_VCPUS = 0xff _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 6f0539c29..9a7be3655 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -72,6 +72,7 @@ const ( _TCR_T0SZ_VA48 = 64 - 48 // VA=48 _TCR_T1SZ_VA48 = 64 - 48 // VA=48 + _TCR_A1 = 1 << 22 _TCR_ASID16 = 1 << 36 _TCR_TBI0 = 1 << 37 @@ -116,6 +117,17 @@ const ( // Arm64: Exception Syndrome Register EL1. const ( + _ESR_ELx_EC_SHIFT = 26 + _ESR_ELx_EC_MASK = 0x3F << _ESR_ELx_EC_SHIFT + + _ESR_ELx_EC_IMP_DEF = 0x1f + _ESR_ELx_EC_IABT_LOW = 0x20 + _ESR_ELx_EC_IABT_CUR = 0x21 + _ESR_ELx_EC_PC_ALIGN = 0x22 + + _ESR_ELx_CM = 1 << 8 + _ESR_ELx_WNR = 1 << 6 + _ESR_ELx_FSC = 0x3F _ESR_SEGV_MAPERR_L0 = 0x4 diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 6c8f4fa28..45b3180f1 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -262,30 +262,6 @@ func TestRegistersFault(t *testing.T) { }) } -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - func TestBounce(t *testing.T) { applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { go func() { diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 6c54712d1..75da253c5 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -43,9 +43,6 @@ type machine struct { // kernel is the set of global structures. kernel ring0.Kernel - // mappingCache is used for mapPhysical. - mappingCache sync.Map - // mu protects vCPUs. mu sync.RWMutex @@ -63,6 +60,12 @@ type machine struct { // maxVCPUs is the maximum number of vCPUs supported by the machine. maxVCPUs int + // maxSlots is the maximum number of memory slots supported by the machine. + maxSlots int + + // usedSlots is the set of used physical addresses (sorted). + usedSlots []uintptr + // nextID is the next vCPU ID. nextID uint32 } @@ -152,7 +155,7 @@ func (m *machine) newVCPU() *vCPU { fd: int(fd), machine: m, } - c.CPU.Init(&m.kernel, c) + c.CPU.Init(&m.kernel, c.id, c) m.vCPUsByID[c.id] = c // Ensure the signal mask is correct. @@ -180,10 +183,8 @@ func newMachine(vm int) (*machine, error) { // Create the machine. m := &machine{fd: vm} m.available.L = &m.mu - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }) + // Pull the maximum vCPUs. maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) if errno != 0 { m.maxVCPUs = _KVM_NR_VCPUS @@ -191,10 +192,21 @@ func newMachine(vm int) (*machine, error) { m.maxVCPUs = int(maxVCPUs) } log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) - - // Create the vCPUs map/slices. m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + m.kernel.Init(ring0.KernelOpts{ + PageTables: pagetables.New(newAllocator()), + }, m.maxVCPUs) + + // Pull the maximum slots. + maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) + if errno != 0 { + m.maxSlots = _KVM_NR_MEMSLOTS + } else { + m.maxSlots = int(maxSlots) + } + log.Debugf("The maximum number of slots is %d.", m.maxSlots) + m.usedSlots = make([]uintptr, m.maxSlots) // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These @@ -207,15 +219,9 @@ func newMachine(vm int) (*machine, error) { pagetables.MapOpts{AccessType: usermem.AnyAccess}, pr.physical) - // And keep everything in the upper half. - m.kernel.PageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. }) + m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -272,6 +278,20 @@ func newMachine(vm int) (*machine, error) { return m, nil } +// hasSlot returns true iff the given address is mapped. +// +// This must be done via a linear scan. +// +//go:nosplit +func (m *machine) hasSlot(physical uintptr) bool { + for i := 0; i < len(m.usedSlots); i++ { + if p := atomic.LoadUintptr(&m.usedSlots[i]); p == physical { + return true + } + } + return false +} + // mapPhysical checks for the mapping of a physical range, and installs one if // not available. This attempts to be efficient for calls in the hot path. // @@ -286,8 +306,8 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg panic("mapPhysical on unknown physical address") } - if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { - // Not present in the cache; requires setting the slot. + // Is this already mapped? Check the usedSlots. + if !m.hasSlot(physicalStart) { if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } @@ -339,6 +359,11 @@ func (m *machine) Destroy() { // Get gets an available vCPU. // // This will return with the OS thread locked. +// +// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points +// to the vCPU in which the OS thread TID is running. So if Get() returns with +// the corrent context in guest, the vCPU of it must be the same as what +// Get() returns. func (m *machine) Get() *vCPU { m.mu.RLock() runtime.LockOSThread() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index acc823ba6..54e721bb1 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -144,6 +144,7 @@ func (c *vCPU) initArchState() error { // Set the entrypoint for the kernel. kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer()) kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer()) + kernelUserRegs.RSP = c.StackTop() kernelUserRegs.RFLAGS = ring0.KernelFlagsSet // Set the system registers. @@ -152,8 +153,8 @@ func (c *vCPU) initArchState() error { } // Set the user registers. - if err := c.setUserRegisters(&kernelUserRegs); err != nil { - return err + if errno := c.setUserRegisters(&kernelUserRegs); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) } // Allocate some floating point state save area for the local vCPU. @@ -345,3 +346,43 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } + +var execRegions []region + +func init() { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { + return + } + + if vr.accessType.Execute { + execRegions = append(execRegions, vr.region) + } + }) +} + +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + for _, r := range execRegions { + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossilbe translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) + } + for start, end := range m.kernel.EntryRegions() { + regionLen := end - start + physical, length, ok := translateToPhysical(start) + if !ok || length < regionLen { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|start), + regionLen, + pagetables.MapOpts{AccessType: usermem.ReadWrite}, + physical) + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 290f035dd..330f29065 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -137,15 +137,17 @@ func (c *vCPU) setSignalMask() error { } // setUserRegisters sets user registers in the vCPU. -func (c *vCPU) setUserRegisters(uregs *userRegs) error { +// +//go:nosplit +func (c *vCPU) setUserRegisters(uregs *userRegs) syscall.Errno { if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_REGS, uintptr(unsafe.Pointer(uregs))); errno != 0 { - return fmt.Errorf("error setting user registers: %v", errno) + return errno } - return nil + return 0 } // getUserRegisters reloads user registers in the vCPU. @@ -175,3 +177,17 @@ func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { } return nil } + +// getSystemRegisters sets system registers. +// +//go:nosplit +func (c *vCPU) getSystemRegisters(sregs *systemRegs) syscall.Errno { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_SREGS, + uintptr(unsafe.Pointer(sregs))); errno != 0 { + return errno + } + return 0 +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index f3bf973de..2df762991 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -19,6 +19,7 @@ package kvm import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -48,6 +49,18 @@ const ( poolPCIDs = 8 ) +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + applyPhysicalRegions(func(pr physicalRegion) bool { + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|pr.virtual), + pr.length, + pagetables.MapOpts{AccessType: usermem.AnyAccess}, + pr.physical) + + return true // Keep iterating. + }) +} + // Get all read-only physicalRegions. func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { var rdonlyRegions []region @@ -125,71 +138,59 @@ func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.Acc return usermem.NoAccess, platform.ErrContextSignal } +// isInstructionAbort returns true if it is an instruction abort. +// +//go:nosplit +func isInstructionAbort(code uint64) bool { + value := (code & _ESR_ELx_EC_MASK) >> _ESR_ELx_EC_SHIFT + return value == _ESR_ELx_EC_IABT_LOW +} + +// isWriteFault returns whether it is a write fault. +// +//go:nosplit +func isWriteFault(code uint64) bool { + if isInstructionAbort(code) { + return false + } + + return (code & _ESR_ELx_WNR) != 0 +} + // fault generates an appropriate fault return. // //go:nosplit func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + bluepill(c) // Probably no-op, but may not be. faultAddr := c.GetFaultAddr() code, user := c.ErrorCode() + if !user { + // The last fault serviced by this CPU was not a user + // fault, so we can't reliably trust the faultAddr or + // the code provided here. We need to re-execute. + return usermem.NoAccess, platform.ErrContextInterrupt + } + // Reset the pointed SignalInfo. *info = arch.SignalInfo{Signo: signal} info.SetAddr(uint64(faultAddr)) - read := true - write := false - execute := true - ret := code & _ESR_ELx_FSC switch ret { case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3: info.Code = 1 //SEGV_MAPERR - read = false - write = true - execute = false case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3: info.Code = 2 // SEGV_ACCERR. - read = true - write = false - execute = false default: info.Code = 2 } - if !user { - read = true - write = false - execute = true - - } accessType := usermem.AccessType{ - Read: read, - Write: write, - Execute: execute, + Read: !isWriteFault(uint64(code)), + Write: isWriteFault(uint64(code)), + Execute: isInstructionAbort(uint64(code)), } return accessType, platform.ErrContextSignal } - -// retryInGuest runs the given function in guest mode. -// -// If the function does not complete in guest mode (due to execution of a -// system call due to a GC stall, for example), then it will be retried. The -// given function must be idempotent as a result of the retry mechanism. -func (m *machine) retryInGuest(fn func()) { - c := m.Get() - defer m.Put(c) - for { - c.ClearErrorCode() // See below. - bluepill(c) // Force guest mode. - fn() // Execute the given function. - _, user := c.ErrorCode() - if user { - // If user is set, then we haven't bailed back to host - // mode via a kernel exception or system call. We - // consider the full function to have executed in guest - // mode and we can return. - break - } - } -} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 8bed34922..537419657 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -60,7 +61,6 @@ func (c *vCPU) initArchState() error { reg.addr = uint64(reflect.ValueOf(&data).Pointer()) regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer()) - vcpuInit.target = _KVM_ARM_TARGET_GENERIC_V8 vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2) if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, @@ -78,21 +78,8 @@ func (c *vCPU) initArchState() error { return err } - // sctlr_el1 - regGet.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.getOneRegister(®Get); err != nil { - return err - } - - dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I) - data = dataGet - reg.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.setOneRegister(®); err != nil { - return err - } - // tcr_el1 - data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1 reg.id = _KVM_ARM64_REGS_TCR_EL1 if err := c.setOneRegister(®); err != nil { return err @@ -116,7 +103,7 @@ func (c *vCPU) initArchState() error { c.SetTtbr0Kvm(uintptr(data)) // ttbr1_el1 - data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1) reg.id = _KVM_ARM64_REGS_TTBR1_EL1 if err := c.setOneRegister(®); err != nil { @@ -163,10 +150,12 @@ func (c *vCPU) initArchState() error { // the MMIO address base. arm64HypercallMMIOBase = toLocation - data = ring0.PsrDefaultSet | ring0.KernelFlagsSet - reg.id = _KVM_ARM64_REGS_PSTATE - if err := c.setOneRegister(®); err != nil { - return err + // Initialize the PCID database. + if hasGuestPCID { + // Note that NewPCIDs may return a nil table here, in which + // case we simply don't use PCID support (see below). In + // practice, this should not happen, however. + c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) } c.floatingPointState = arch.NewFloatingPointData() @@ -247,6 +236,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) } + // Assign PCIDs. + if c.PCIDs != nil { + var requireFlushPCID bool // Force a flush? + switchOpts.UserASID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables) + switchOpts.Flush = switchOpts.Flush || requireFlushPCID + } + var vector ring0.Vector ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0) c.SetTtbr0App(uintptr(ttbr0App)) @@ -275,8 +271,16 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return c.fault(int32(syscall.SIGSEGV), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt + case ring0.El0Sync_undef, + ring0.El1Sync_undef: + *info = arch.SignalInfo{ + Signo: int32(syscall.SIGILL), + Code: 1, // ILL_ILLOPC (illegal opcode). + } + info.SetAddr(switchOpts.Registers.Pc) // Include address. + return usermem.AccessType{}, platform.ErrContextSignal default: - return usermem.NoAccess, platform.ErrContextSignal + panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) } } diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 9f86f6a7a..607c82156 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go index ca902c8c1..4dad877ba 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -56,5 +56,9 @@ func CheckTestRegs(regs *arch.Registers, full bool) (err error) { err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) } } + // Check tls. + if need := ^uint64(11); regs.TPIDR_EL0 != need { + err = addRegisterMismatch(err, "tpdir_el0", regs.TPIDR_EL0, need) + } return } diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 0bebee852..6caf7282d 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -52,6 +52,8 @@ start: TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 NO_LOCAL_POINTERS + // gc will touch fpsimd, so we should test it. + // such as in <runtime.deductSweepCredit>. FMOVD $(9.9), F0 MOVD $SYS_GETPID, R8 // getpid SVC @@ -102,5 +104,15 @@ isNaN: TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 TWIDDLE_REGS() + MSR R10, TPIDR_EL0 + // Trapped in el0_svc. SVC RET // never reached + +TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 + TWIDDLE_REGS() + MSR R10, TPIDR_EL0 + // Trapped in el0_ia. + // Branch to Register branches unconditionally to an address in <Rn>. + JMP (R6) // <=> br x6, must fault + RET // never reached diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go index c8897d34f..4dcdbf8a7 100644 --- a/pkg/sentry/platform/kvm/virtual_map.go +++ b/pkg/sentry/platform/kvm/virtual_map.go @@ -34,7 +34,7 @@ type virtualRegion struct { } // mapsLine matches a single line from /proc/PID/maps. -var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)") +var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2,3}:[0-9a-f]{2,} [0-9]+\\s+(.*)") // excludeRegion returns true if these regions should be excluded from the // physical map. Virtual regions need to be excluded if get_user_pages will diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 171513f3f..530e779b0 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -22,9 +22,10 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -114,6 +115,17 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error { panic("This platform does not support CPU preemption detection") } +// MemoryManager represents an abstraction above the platform address space +// which manages memory mappings and their contents. +type MemoryManager interface { + //usermem.IO provides access to the contents of a virtual memory space. + usermem.IO + // MMap establishes a memory mapping. + MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error) + // AddressSpace returns the AddressSpace bound to mm. + AddressSpace() AddressSpace +} + // Context represents the execution context for a single thread. type Context interface { // Switch resumes execution of the thread specified by the arch.Context @@ -143,7 +155,36 @@ type Context interface { // concurrent call to Switch(). // // - ErrContextCPUPreempted: See the definition of that error for details. - Switch(as AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) + Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) + + // PullFullState() pulls a full state of the application thread. + // + // A platform can support lazy loading/restoring of a thread state + // which includes registers and a floating point state. + // + // For example, when the Sentry handles a system call, it may have only + // syscall arguments without other registers and a floating point + // state. And in this case, if the Sentry will need to construct a + // signal frame to call a signal handler, it will need to call + // PullFullState() to load all registers and FPU state. + // + // Preconditions: The caller must be running on the task goroutine. + PullFullState(as AddressSpace, ac arch.Context) + + // FullStateChanged() indicates that a thread state has been changed by + // the Sentry. This happens in case of the rt_sigreturn, execve, etc. + // + // First, it indicates that the Sentry has the full state of the thread + // and PullFullState() has to do nothing if it is called after + // FullStateChanged(). + // + // Second, it forces restoring the full state of the application + // thread. A platform can support lazy loading/restoring of a thread + // state. This means that if the Sentry has not changed a thread state, + // the platform may not restore it. + // + // Preconditions: The caller must be running on the task goroutine. + FullStateChanged() // Interrupt interrupts a concurrent call to Switch(), causing it to return // ErrContextInterrupt. @@ -204,20 +245,32 @@ type AddressSpace interface { // physical memory) to the mapping. The precommit flag is advisory and // implementations may choose to ignore it. // - // Preconditions: addr and fr must be page-aligned. fr.Length() > 0. - // at.Any() == true. At least one reference must be held on all pages in - // fr, and must continue to be held as long as pages are mapped. - MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error + // Preconditions: + // * addr and fr must be page-aligned. + // * fr.Length() > 0. + // * at.Any() == true. + // * At least one reference must be held on all pages in fr, and must + // continue to be held as long as pages are mapped. + MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error // Unmap unmaps the given range. // - // Preconditions: addr is page-aligned. length > 0. + // Preconditions: + // * addr is page-aligned. + // * length > 0. Unmap(addr usermem.Addr, length uint64) // Release releases this address space. After releasing, a new AddressSpace // must be acquired via platform.NewAddressSpace(). Release() + // PreFork() is called before creating a copy of AddressSpace. This + // guarantees that this address space will be in a consistent state. + PreFork() + + // PostFork() is called after creating a copy of AddressSpace. + PostFork() + // AddressSpaceIO methods are supported iff the associated platform's // Platform.SupportsAddressSpaceIO() == true. AddressSpaces for which this // does not hold may panic if AddressSpaceIO methods are invoked. @@ -310,52 +363,6 @@ func (f SegmentationFault) Error() string { return fmt.Sprintf("segmentation fault at %#x", f.Addr) } -// File represents a host file that may be mapped into an AddressSpace. -type File interface { - // All pages in a File are reference-counted. - - // IncRef increments the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. (The File - // interface does not provide a way to acquire an initial reference; - // implementors may define mechanisms for doing so.) - IncRef(fr FileRange) - - // DecRef decrements the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. - DecRef(fr FileRange) - - // MapInternal returns a mapping of the given file offsets in the invoking - // process' address space for reading and writing. - // - // Note that fr.Start and fr.End need not be page-aligned. - // - // Preconditions: fr.Length() > 0. At least one reference must be held on - // all pages in fr. - // - // Postconditions: The returned mapping is valid as long as at least one - // reference is held on the mapped pages. - MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) - - // FD returns the file descriptor represented by the File. - // - // The only permitted operation on the returned file descriptor is to map - // pages from it consistent with the requirements of AddressSpace.MapFile. - FD() int -} - -// FileRange represents a range of uint64 offsets into a File. -// -// type FileRange <generated using go_generics> - -// String implements fmt.Stringer.String. -func (fr FileRange) String() string { - return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) -} - // Requirements is used to specify platform specific requirements. type Requirements struct { // RequiresCurrentPIDNS indicates that the sandbox has to be started in the diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 30402c2df..fc43cc3c0 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -24,12 +24,13 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/procid", "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/hostcpu", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sync", diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go index 1e07cfd0d..b0970e356 100644 --- a/pkg/sentry/platform/ptrace/filters.go +++ b/pkg/sentry/platform/ptrace/filters.go @@ -24,10 +24,9 @@ import ( // SyscallFilters returns syscalls made exclusively by the ptrace platform. func (*PTrace) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - unix.SYS_GETCPU: {}, - unix.SYS_SCHED_SETAFFINITY: {}, - syscall.SYS_PTRACE: {}, - syscall.SYS_TGKILL: {}, - syscall.SYS_WAIT4: {}, + unix.SYS_GETCPU: {}, + syscall.SYS_PTRACE: {}, + syscall.SYS_TGKILL: {}, + syscall.SYS_WAIT4: {}, } } diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 08d055e05..b52d0fbd8 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -48,6 +48,7 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" + pkgcontext "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" @@ -95,7 +96,8 @@ type context struct { } // Switch runs the provided context in the given address space. -func (c *context) Switch(as platform.AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) { + as := mm.AddressSpace() s := as.(*subprocess) isSyscall := s.switchToApp(c, ac) @@ -180,6 +182,12 @@ func (c *context) Interrupt() { // Release implements platform.Context.Release(). func (c *context) Release() {} +// FullStateChanged implements platform.Context.FullStateChanged. +func (c *context) FullStateChanged() {} + +// PullFullState implements platform.Context.PullFullState. +func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {} + // PTrace represents a collection of ptrace subprocesses. type PTrace struct { platform.MMapMinAddr diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 2389423b0..812ab80ef 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -517,11 +518,6 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { } defer c.interrupt.Disable() - // Ensure that the CPU set is bound appropriately; this makes the - // emulation below several times faster, presumably by avoiding - // interprocessor wakeups and by simplifying the schedule. - t.bind() - // Set registers. if err := t.setRegs(regs); err != nil { panic(fmt.Sprintf("ptrace set regs (%+v) failed: %v", regs, err)) @@ -616,7 +612,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp } // MapFile implements platform.AddressSpace.MapFile. -func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { var flags int if precommit { flags |= syscall.MAP_POPULATE @@ -661,3 +657,9 @@ func (s *subprocess) Unmap(addr usermem.Addr, length uint64) { panic(fmt.Sprintf("munmap(%x, %x)) failed: %v", addr, length, err)) } } + +// PreFork implements platform.AddressSpace.PreFork. +func (s *subprocess) PreFork() {} + +// PostFork implements platform.AddressSpace.PostFork. +func (s *subprocess) PostFork() {} diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 84b699f0d..020bbda79 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -201,7 +201,7 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi seccomp.RuleSet{ Rules: seccomp.SyscallRules{ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + {seccomp.EqualTo(linux.ARCH_SET_CPUID), seccomp.EqualTo(0)}, }, }, Action: linux.SECCOMP_RET_ALLOW, diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index 2ce528601..8548853da 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -80,9 +80,9 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro Rules: seccomp.SyscallRules{ syscall.SYS_CLONE: []seccomp.Rule{ // Allow creation of new subprocesses (used by the master). - {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)}, + {seccomp.EqualTo(syscall.CLONE_FILES | syscall.SIGKILL)}, // Allow creation of new threads within a single address space (used by addresss spaces). - {seccomp.AllowValue( + {seccomp.EqualTo( syscall.CLONE_FILES | syscall.CLONE_FS | syscall.CLONE_SIGHAND | @@ -97,14 +97,14 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // For the stub prctl dance (all). syscall.SYS_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)}, + {seccomp.EqualTo(syscall.PR_SET_PDEATHSIG), seccomp.EqualTo(syscall.SIGKILL)}, }, syscall.SYS_GETPPID: {}, // For the stub to stop itself (all). syscall.SYS_GETPID: {}, syscall.SYS_KILL: []seccomp.Rule{ - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SIGSTOP)}, }, // Injected to support the address space operations. @@ -115,7 +115,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }) } rules = appendArchSeccompRules(rules, defaultAction) - instrs, err := seccomp.BuildProgram(rules, defaultAction) + instrs, err := seccomp.BuildProgram(rules, defaultAction, defaultAction) if err != nil { return nil, err } diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 245b20722..533e45497 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,29 +18,12 @@ package ptrace import ( - "sync/atomic" "syscall" "unsafe" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/hostcpu" - "gvisor.dev/gvisor/pkg/sync" ) -// maskPool contains reusable CPU masks for setting affinity. Unfortunately, -// runtime.NumCPU doesn't actually record the number of CPUs on the system, it -// just records the number of CPUs available in the scheduler affinity set at -// startup. This may a) change over time and b) gives a number far lower than -// the maximum indexable CPU. To prevent lots of allocation in the hot path, we -// use a pool to store large masks that we can reuse during bind. -var maskPool = sync.Pool{ - New: func() interface{} { - const maxCPUs = 1024 // Not a hard limit; see below. - return make([]uintptr, maxCPUs/64) - }, -} - // unmaskAllSignals unmasks all signals on the current thread. // //go:nosplit @@ -49,47 +32,3 @@ func unmaskAllSignals() syscall.Errno { _, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0) return errno } - -// setCPU sets the CPU affinity. -func (t *thread) setCPU(cpu uint32) error { - mask := maskPool.Get().([]uintptr) - n := int(cpu / 64) - v := uintptr(1 << uintptr(cpu%64)) - if n >= len(mask) { - // See maskPool note above. We've actually exceeded the number - // of available cores. Grow the mask and return it. - mask = make([]uintptr, n+1) - } - mask[n] |= v - if _, _, errno := syscall.RawSyscall( - unix.SYS_SCHED_SETAFFINITY, - uintptr(t.tid), - uintptr(len(mask)*8), - uintptr(unsafe.Pointer(&mask[0]))); errno != 0 { - return errno - } - mask[n] &^= v - maskPool.Put(mask) - return nil -} - -// bind attempts to ensure that the thread is on the same CPU as the current -// thread. This provides no guarantees as it is fundamentally a racy operation: -// CPU sets may change and we may be rescheduled in the middle of this -// operation. As a result, no failures are reported. -// -// Precondition: the current runtime thread should be locked. -func (t *thread) bind() { - currentCPU := hostcpu.GetCPU() - - if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { - // Set the affinity on the thread and save the CPU for next - // round; we don't expect CPUs to bounce around too frequently. - // - // (It's worth noting that we could move CPUs between this point - // and when the tracee finishes executing. But that would be - // roughly the status quo anyways -- we're just maximizing our - // chances of colocation, not guaranteeing it.) - t.setCPU(currentCPU) - } -} diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index 0bee995e4..7ee20d89a 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 8122ac6e2..87a573cc4 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -26,30 +26,31 @@ const ( _PMD_PGT_SIZE = 0x4000 _PTE_PGT_BASE = 0x7000 _PTE_PGT_SIZE = 0x1000 - - _PSR_D_BIT = 0x00000200 - _PSR_A_BIT = 0x00000100 - _PSR_I_BIT = 0x00000080 - _PSR_F_BIT = 0x00000040 ) const ( - // PSR bits - PSR_MODE_EL0t = 0x00000000 - PSR_MODE_EL1t = 0x00000004 - PSR_MODE_EL1h = 0x00000005 - PSR_MODE_MASK = 0x0000000f + // DAIF bits:debug, sError, IRQ, FIQ. + _PSR_D_BIT = 0x00000200 + _PSR_A_BIT = 0x00000100 + _PSR_I_BIT = 0x00000080 + _PSR_F_BIT = 0x00000040 + _PSR_DAIF_SHIFT = 6 + _PSR_DAIF_MASK = 0xf << _PSR_DAIF_SHIFT - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = PSR_MODE_EL1h + // PSR bits. + _PSR_MODE_EL0t = 0x00000000 + _PSR_MODE_EL1t = 0x00000004 + _PSR_MODE_EL1h = 0x00000005 + _PSR_MODE_MASK = 0x0000000f - // UserFlagsSet are always set in userspace. - UserFlagsSet = PSR_MODE_EL0t + PsrFlagsClear = _PSR_MODE_MASK | _PSR_DAIF_MASK + PsrModeMask = _PSR_MODE_MASK - KernelFlagsClear = PSR_MODE_MASK - UserFlagsClear = PSR_MODE_MASK + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _PSR_MODE_EL1h | _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT - PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT + // UserFlagsSet are always set in userspace. + UserFlagsSet = _PSR_MODE_EL0t ) // Vector is an exception vector. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 9c6c2cf5c..f617519fa 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -76,15 +76,42 @@ type KernelOpts struct { type KernelArchState struct { KernelOpts + // cpuEntries is array of kernelEntry for all cpus + cpuEntries []kernelEntry + // globalIDT is our set of interrupt gates. - globalIDT idt64 + globalIDT *idt64 } -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { +// kernelEntry contains minimal CPU-specific arch state +// that can be mapped at the upper of the address space. +// Malicious APP might steal info from it via CPU bugs. +type kernelEntry struct { // stack is the stack used for interrupts on this CPU. stack [256]byte + // scratch space for temporary usage. + scratch0 uint64 + scratch1 uint64 + + // stackTop is the top of the stack. + stackTop uint64 + + // cpuSelf is back reference to CPU. + cpuSelf *CPU + + // kernelCR3 is the cr3 used for sentry kernel. + kernelCR3 uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { // errorCode is the error code from the last exception. errorCode uintptr @@ -97,11 +124,7 @@ type CPUArchState struct { // exception. errorType uintptr - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 + *kernelEntry } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 0e2ab716c..508236e46 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -77,6 +77,9 @@ type CPUArchState struct { // lazyVFP is the value of cpacr_el1. lazyVFP uintptr + + // appASID is the asid value of guest application. + appASID uintptr } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go index 7fa43c2f5..d87b1fd00 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/sentry/platform/ring0/entry_amd64.go @@ -36,12 +36,15 @@ func sysenter() // This must be called prior to sysret/iret. func swapgs() +// jumpToKernel jumps to the kernel version of the current RIP. +func jumpToKernel() + // sysret returns to userspace from a system call. // // The return code is the vector that interrupted execution. // // See stubs.go for a note regarding the frame size of this function. -func sysret(*CPU, *arch.Registers) Vector +func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // "iret is the cadillac of CPL switching." // @@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector // iret is nearly identical to sysret, except an iret is used to fully restore // all user state. This must be called in cases where all registers need to be // restored. -func iret(*CPU, *arch.Registers) Vector +func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // exception is the generic exception entry. // diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s index 02df38331..f59747df3 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/sentry/platform/ring0/entry_amd64.s @@ -63,6 +63,15 @@ MOVQ offset+PTRACE_RSI(reg), SI; \ MOVQ offset+PTRACE_RDI(reg), DI; +// WRITE_CR3() writes the given CR3 value. +// +// The code corresponds to: +// +// mov %rax, %cr3 +// +#define WRITE_CR3() \ + BYTE $0x0f; BYTE $0x22; BYTE $0xd8; + // SWAP_GS swaps the kernel GS (CPU). #define SWAP_GS() \ BYTE $0x0F; BYTE $0x01; BYTE $0xf8; @@ -75,15 +84,9 @@ #define SYSRET64() \ BYTE $0x48; BYTE $0x0f; BYTE $0x07; -// LOAD_KERNEL_ADDRESS loads a kernel address. -#define LOAD_KERNEL_ADDRESS(from, to) \ - MOVQ from, to; \ - ORQ ·KernelStartAddress(SB), to; - // LOAD_KERNEL_STACK loads the kernel stack. -#define LOAD_KERNEL_STACK(from) \ - LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \ - LEAQ CPU_STACK_TOP(SP), SP; +#define LOAD_KERNEL_STACK(entry) \ + MOVQ ENTRY_STACK_TOP(entry), SP; // See kernel.go. TEXT ·Halt(SB),NOSPLIT,$0 @@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0 SWAP_GS() RET +// jumpToKernel changes execution to the kernel address space. +// +// This works by changing the return value to the kernel version. +TEXT ·jumpToKernel(SB),NOSPLIT,$0 + MOVQ 0(SP), AX + ORQ ·KernelStartAddress(SB), AX // Future return value. + MOVQ AX, 0(SP) + RET + // See entry_amd64.go. TEXT ·sysret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) + // save SP AX userCR3 on the kernel stack. + MOVQ CPU_ENTRY(BX), BX + LOAD_KERNEL_STACK(BX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_RAX(AX) + PUSHQ CX + // Restore user register state. REGISTERS_LOAD(AX, 0) MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET. MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET. - MOVQ PTRACE_RSP(AX), SP // Restore the stack directly. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + + // restore userCR3, AX, SP. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. + POPQ SP // Restore SP. SYSRET64() // See entry_amd64.go. TEXT ·iret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) // Build an IRET frame & restore state. + MOVQ CPU_ENTRY(BX), BX LOAD_KERNEL_STACK(BX) - MOVQ PTRACE_SS(AX), BX; PUSHQ BX - MOVQ PTRACE_RSP(AX), CX; PUSHQ CX - MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX - MOVQ PTRACE_CS(AX), DI; PUSHQ DI - MOVQ PTRACE_RIP(AX), SI; PUSHQ SI - REGISTERS_LOAD(AX, 0) // Restore most registers. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + PUSHQ PTRACE_SS(AX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_FLAGS(AX) + PUSHQ PTRACE_CS(AX) + PUSHQ PTRACE_RIP(AX) + PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack. + PUSHQ CX // Save userCR3 on kernel stack. + REGISTERS_LOAD(AX, 0) // Restore most registers. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. IRET() // See entry_amd64.go. TEXT ·resume(SB),NOSPLIT,$0 // See iret, above. - MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX - MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX - MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI - MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI - REGISTERS_LOAD(GS, CPU_REGISTERS) - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + PUSHQ CPU_REGISTERS+PTRACE_SS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RSP(AX) + PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX) + PUSHQ CPU_REGISTERS+PTRACE_CS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RIP(AX) + REGISTERS_LOAD(AX, CPU_REGISTERS) + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX IRET() // See entry_amd64.go. TEXT ·Start(SB),NOSPLIT,$0 - LOAD_KERNEL_STACK(AX) // Set the stack. PUSHQ $0x0 // Previous frame pointer. MOVQ SP, BP // Set frame pointer. PUSHQ AX // First argument (CPU). @@ -155,53 +193,60 @@ TEXT ·Start(SB),NOSPLIT,$0 // See entry_amd64.go. TEXT ·sysenter(SB),NOSPLIT,$0 - // Interrupts are always disabled while we're executing in kernel mode - // and always enabled while executing in user mode. Therefore, we can - // reliably look at the flags in R11 to determine where this syscall - // was from. - TESTL $_RFLAGS_IF, R11 + // _RFLAGS_IOPL0 is always set in the user mode and it is never set in + // the kernel mode. See the comment of UserFlagsSet for more details. + TESTL $_RFLAGS_IOPL0, R11 JZ kernel - user: SWAP_GS() - XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks. - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs). + MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value. - MOVQ BX, PTRACE_RAX(AX) // Save everything else. - MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ CX, PTRACE_RIP(AX) MOVQ R11, PTRACE_FLAGS(AX) - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. + MOVQ SP, PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value. + MOVQ CX, PTRACE_RAX(AX) // Save everything else. + MOVQ CX, PTRACE_ORIGRAX(AX) + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks. + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. // Return to the kernel, where the frame is: // - // vector (sp+24) + // vector (sp+32) + // userCR3 (sp+24) // regs (sp+16) // cpu (sp+8) // vcpu.Switch (sp+0) // - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ $Syscall, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ $Syscall, 32(SP) // Output vector. RET kernel: // We can't restore the original stack, but we can access the registers // in the CPU state directly. No need for temporary juggling. - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ AX, ENTRY_SCRATCH0(GS) + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), BX + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. // Call the syscall trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ AX // First argument (vCPU). CALL ·kernelSyscall(SB) // Call the trampoline. POPQ AX // Pop vCPU. @@ -230,16 +275,21 @@ TEXT ·exception(SB),NOSPLIT,$0 // ERROR_CODE (sp+8) // VECTOR (sp+0) // - TESTL $_RFLAGS_IF, 32(SP) + TESTL $_RFLAGS_IOPL0, 32(SP) JZ kernel user: SWAP_GS() ADDQ $-8, SP // Adjust for flags. MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ). - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs. + PUSHQ AX // Save user AX on stack. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX. + POPQ BX // Restore original AX. MOVQ BX, PTRACE_RAX(AX) // Save it. MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX) @@ -249,34 +299,36 @@ user: MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX) // Copy out and return. + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. MOVQ 0(SP), BX // Load vector. MOVQ 8(SP), CX // Load error code. - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version). - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ CX, CPU_ERROR_CODE(GS) // Set error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. - MOVQ BX, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version). + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ CX, CPU_ERROR_CODE(AX) // Set error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. + MOVQ BX, 32(SP) // Output vector. RET kernel: // As per above, we can save directly. - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS) + PUSHQ AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + POPQ BX + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX) // Set the error code and adjust the stack. - MOVQ 8(SP), AX // Load the error code. - MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ 8(SP), BX // Load the error code. + MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. MOVQ 0(SP), BX // BX contains the vector. - ADDQ $48, SP // Drop the exception frame. // Call the exception trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ BX // Second argument (vector). PUSHQ AX // First argument (vCPU). CALL ·kernelException(SB) // Call the trampoline. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2bc5f3ecd..f3d934996 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -27,7 +27,9 @@ // ERET returns using the ELR and SPSR for the current exception level. #define ERET() \ - WORD $0xd69f03e0 + WORD $0xd69f03e0; \ + DSB $7; \ + ISB $15; // RSV_REG is a register that holds el1 information temporarily. #define RSV_REG R18_PLATFORM @@ -40,6 +42,20 @@ #define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) +// sctlr_el1: system control register el1. +#define SCTLR_M 1 << 0 +#define SCTLR_C 1 << 2 +#define SCTLR_I 1 << 12 +#define SCTLR_UCT 1 << 15 + +#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT) + +// cntkctl_el1: counter-timer kernel control register el1. +#define CNTKCTL_EL0PCTEN 1 << 0 +#define CNTKCTL_EL0VCTEN 1 << 1 + +#define CNTKCTL_EL1_DEFAULT (CNTKCTL_EL0PCTEN | CNTKCTL_EL0VCTEN) + // Saves a register set. // // This is a macro because it may need to executed in contents where a stack is @@ -286,23 +302,23 @@ // SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application. #define SWITCH_TO_APP_PAGETABLE(from) \ - MOVD CPU_TTBR0_APP(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD CPU_APP_ASID(from), R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set) ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; // SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable. #define SWITCH_TO_KVM_PAGETABLE(from) \ - MOVD CPU_TTBR0_KVM(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD $1, R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ ISB $15; \ - DSB $15; - -#define IRQ_ENABLE \ - MSR $2, DAIFSet; - -#define IRQ_DISABLE \ - MSR $2, DAIFClr; + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ @@ -318,23 +334,20 @@ #define KERNEL_ENTRY_FROM_EL0 \ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable. - SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 - MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. - REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + WORD $0xd538d092; \ // MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step2, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step3, save app context. MOVD RSV_REG_APP, R20; \ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ ADD $16, RSP, RSP; \ MOVD RSV_REG, PTRACE_R18(R20); \ MOVD RSV_REG_APP, PTRACE_R9(R20); \ - MOVD R20, RSV_REG_APP; \ WORD $0xd5384003; \ // MRS SPSR_EL1, R3 - MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MOVD R3, PTRACE_PSTATE(R20); \ MRS ELR_EL1, R3; \ - MOVD R3, PTRACE_PC(RSV_REG_APP); \ + MOVD R3, PTRACE_PC(R20); \ WORD $0xd5384103; \ // MRS SP_EL0, R3 - MOVD R3, PTRACE_SP(RSV_REG_APP); + MOVD R3, PTRACE_SP(R20); // KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1. #define KERNEL_ENTRY_FROM_EL1 \ @@ -349,6 +362,13 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// storeAppASID writes the application's asid value. +TEXT ·storeAppASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + MRS TPIDR_EL1, RSV_REG + MOVD R1, CPU_APP_ASID(RSV_REG) + RET + // Halt halts execution. TEXT ·Halt(SB),NOSPLIT,$0 // Clear bluepill. @@ -356,6 +376,9 @@ TEXT ·Halt(SB),NOSPLIT,$0 CMP RSV_REG, R9 BNE mmio_exit MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG) + + // Flush dcache. + WORD $0xd5087e52 // DC CISW mmio_exit: // Disable fpsimd. WORD $0xd5381041 // MRS CPACR_EL1, R1 @@ -373,6 +396,9 @@ mmio_exit: MRS VBAR_EL1, R9 MOVD R0, 0x0(R9) + // Flush dcahce. + WORD $0xd5087e52 // DC CISW + RET // HaltAndResume halts execution and point the pointer to the resume function. @@ -400,12 +426,13 @@ TEXT ·Current(SB),NOSPLIT,$0-8 MOVD R8, ret+0(FP) RET -#define STACK_FRAME_SIZE 16 +#define STACK_FRAME_SIZE 32 // kernelExitToEl0 is the entrypoint for application in guest_el0. // Prepare the vcpu environment for container application. TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 // Step1, save sentry context into memory. + MRS TPIDR_EL1, RSV_REG REGISTERS_SAVE(RSV_REG, CPU_REGISTERS) MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG) @@ -417,34 +444,13 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3 - // Step2, save SP_EL1, PSTATE into kernel temporary stack. - // switch to temporary stack. + // Step2, switch to temporary stack. LOAD_KERNEL_STACK(RSV_REG) - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - - SUB $STACK_FRAME_SIZE, RSP, RSP - MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11 - MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12 - STP (R11, R12), 16*0(RSP) - MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11 - MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12 - - // Step3, test user pagetable. - // If user pagetable is empty, trapped in el1_ia. - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - SWITCH_TO_APP_PAGETABLE(RSV_REG) - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - SWITCH_TO_KVM_PAGETABLE(RSV_REG) - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - - // If pagetable is not empty, recovery kernel temporary stack. - ADD $STACK_FRAME_SIZE, RSP, RSP - - // Step4, load app context pointer. + // Step3, load app context pointer. MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP - // Step5, prepare the environment for container application. + // Step4, prepare the environment for container application. // set sp_el0. MOVD PTRACE_SP(RSV_REG_APP), R1 WORD $0xd5184101 //MSR R1, SP_EL0 @@ -455,6 +461,14 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 MOVD PTRACE_PSTATE(RSV_REG_APP), R1 WORD $0xd5184001 //MSR R1, SPSR_EL1 + // need use kernel space address to excute below code, since + // after SWITCH_TO_APP_PAGETABLE the ASID is changed to app's + // ASID. + WORD $0x10000061 // ADR R1, do_exit_to_el0 + ORR $0xffff000000000000, R1, R1 + JMP (R1) + +do_exit_to_el0: // RSV_REG & RSV_REG_APP will be loaded at the end. REGISTERS_LOAD(RSV_REG_APP, 0) @@ -464,11 +478,13 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 SUB $STACK_FRAME_SIZE, RSP, RSP STP (RSV_REG, RSV_REG_APP), 16*0(RSP) + STP (R0, R1), 16*1(RSP) WORD $0xd538d092 //MRS TPIDR_EL1, R18 SWITCH_TO_APP_PAGETABLE(RSV_REG) + LDP 16*1(RSP), (R0, R1) LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) ADD $STACK_FRAME_SIZE, RSP, RSP @@ -478,7 +494,6 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 // Prepare the vcpu environment for sentry. TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 WORD $0xd538d092 //MRS TPIDR_EL1, R18 - MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1 WORD $0xd5184001 //MSR R1, SPSR_EL1 @@ -488,6 +503,9 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP + SWITCH_TO_KVM_PAGETABLE(RSV_REG) + MRS TPIDR_EL1, RSV_REG + REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP @@ -495,7 +513,15 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 - IRQ_DISABLE + // Flush dcache. + WORD $0xd5087e52 // DC CISW + // Init. + MOVD $SCTLR_EL1_DEFAULT, R1 + MSR R1, SCTLR_EL1 + + MOVD $CNTKCTL_EL1_DEFAULT, R1 + MSR R1, CNTKCTL_EL1 + MOVD R8, RSV_REG ORR $0xffff000000000000, RSV_REG, RSV_REG WORD $0xd518d092 //MSR R18, TPIDR_EL1 @@ -544,6 +570,7 @@ TEXT ·El1_sync(SB),NOSPLIT,$0 B el1_invalid el1_da: +el1_ia: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -556,9 +583,6 @@ el1_da: B ·HaltAndResume(SB) -el1_ia: - B ·HaltAndResume(SB) - el1_sp_pc: B ·Shutdown(SB) @@ -630,9 +654,10 @@ el0_svc: MOVD $Syscall, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + B ·kernelExitToEl1(SB) el0_da: +el0_ia: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -644,10 +669,10 @@ el0_da: MOVD $PageFault, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + MRS ESR_EL1, R3 + MOVD R3, CPU_ERROR_CODE(RSV_REG) -el0_ia: - B ·Shutdown(SB) + B ·kernelExitToEl1(SB) el0_fpsimd_acc: B ·Shutdown(SB) @@ -662,7 +687,10 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - B ·Shutdown(SB) + MOVD $El0Sync_undef, R3 + MOVD R3, CPU_VECTOR_CODE(RSV_REG) + + B ·kernelExitToEl1(SB) el0_dbg: B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 549f3d228..9742308d8 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,7 +24,10 @@ go_binary( "defs_impl_arm64.go", "main.go", ], - visibility = ["//pkg/sentry/platform/ring0:__pkg__"], + visibility = [ + "//pkg/sentry/platform/kvm:__pkg__", + "//pkg/sentry/platform/ring0:__pkg__", + ], deps = [ "//pkg/cpuid", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 021693791..264be23d3 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -19,8 +19,8 @@ package ring0 // N.B. that constraints on KernelOpts must be satisfied. // //go:nosplit -func (k *Kernel) Init(opts KernelOpts) { - k.init(opts) +func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { + k.init(opts, maxCPUs) } // Halt halts execution. @@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) { // kernelSyscall is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) { // kernelException is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) { // Init initializes a new CPU. // // Init allows embedding in other objects. -func (c *CPU) Init(k *Kernel, hooks Hooks) { - c.self = c // Set self reference. - c.kernel = k // Set kernel reference. - c.init() // Perform architectural init. +func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) { + c.self = c // Set self reference. + c.kernel = k // Set kernel reference. + c.init(cpuID) // Perform architectural init. // Require hooks. if hooks != nil { diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index d37981dbf..3a9dff4cc 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -18,13 +18,42 @@ package ring0 import ( "encoding/binary" + "reflect" + + "gvisor.dev/gvisor/pkg/usermem" ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables + entrySize := reflect.TypeOf(kernelEntry{}).Size() + var ( + entries []kernelEntry + padding = 1 + ) + for { + entries = make([]kernelEntry, maxCPUs+padding-1) + totalSize := entrySize * uintptr(maxCPUs+padding-1) + addr := reflect.ValueOf(&entries[0]).Pointer() + if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize { + // The runtime forces power-of-2 alignment for allocations, and we are therefore + // safe once the first address is aligned and the chunk is at least a full page. + break + } + padding = padding << 1 + } + k.cpuEntries = entries + + k.globalIDT = &idt64{} + if reflect.TypeOf(idt64{}).Size() != usermem.PageSize { + panic("Size of globalIDT should be PageSize") + } + if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 { + panic("Allocated globalIDT should be page aligned") + } + // Setup the IDT, which is uniform. for v, handler := range handlers { // Allow Breakpoint and Overflow to be called from all @@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) { } } +func (k *Kernel) EntryRegions() map[uintptr]uintptr { + regions := make(map[uintptr]uintptr) + + addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer() + size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries)) + end, _ := usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + addr = reflect.ValueOf(k.globalIDT).Pointer() + size = reflect.TypeOf(idt64{}).Size() + end, _ = usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + return regions +} + // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { + c.kernelEntry = &c.kernel.cpuEntries[cpuID] + c.cpuSelf = c // Null segment. c.gdt[0].setNull() @@ -65,6 +112,7 @@ func (c *CPU) init() { // Set the kernel stack pointer in the TSS (virtual address). stackAddr := c.StackTop() + c.stackTop = stackAddr c.tss.rsp0Lo = uint32(stackAddr) c.tss.rsp0Hi = uint32(stackAddr >> 32) c.tss.ist1Lo = uint32(stackAddr) @@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID) - kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID) + c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)) // Sanitize registers. regs := switchOpts.Registers @@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point. - jumpToKernel() // Switch to upper half. - writeCR3(uintptr(userCR3)) // Change to user address space. if switchOpts.FullRestore { - vector = iret(c, regs) + vector = iret(c, regs, uintptr(userCR3)) } else { - vector = sysret(c, regs) + vector = sysret(c, regs, uintptr(userCR3)) } - writeCR3(uintptr(kernelCR3)) // Return to kernel address space. - jumpToUser() // Return to lower half. SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point. WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return @@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { //go:nosplit func start(c *CPU) { // Save per-cpu & FS segment. - WriteGS(kernelAddr(c)) + WriteGS(kernelAddr(c.kernelEntry)) WriteFS(uintptr(c.registers.Fs_base)) // Initialize floating point. diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index ccacaea6b..0ca98a7c7 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,13 +25,13 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables } // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { // Set the kernel stack pointer(virtual address). c.registers.Sp = uint64(c.StackTop()) @@ -53,14 +53,25 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { - // Sanitize registers. + storeAppASID(uintptr(switchOpts.UserASID)) + if switchOpts.Flush { + FlushTlbAll() + } + regs := switchOpts.Registers - regs.Pstate &= ^uint64(UserFlagsClear) + regs.Pstate &= ^uint64(PsrFlagsClear) regs.Pstate |= UserFlagsSet + + LoadFloatingPoint(switchOpts.FloatingPointState) + SetTLS(regs.TPIDR_EL0) + kernelExitToEl0() + + regs.TPIDR_EL0 = GetTLS() + SaveFloatingPoint(switchOpts.FloatingPointState) + vector = c.vecCode - // Perform the switch. return } diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go index ca968a036..0ec5c3bc5 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.go +++ b/pkg/sentry/platform/ring0/lib_amd64.go @@ -61,21 +61,9 @@ func wrgsbase(addr uintptr) // wrgsmsr writes to the GS_BASE MSR. func wrgsmsr(addr uintptr) -// writeCR3 writes the CR3 value. -func writeCR3(phys uintptr) - -// readCR3 reads the current CR3 value. -func readCR3() uintptr - // readCR2 reads the current CR2 value. func readCR2() uintptr -// jumpToKernel jumps to the kernel version of the current RIP. -func jumpToKernel() - -// jumpToUser jumps to the user version of the current RIP. -func jumpToUser() - // fninit initializes the floating point unit. func fninit() diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s index 75d742750..2fe83568a 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.s +++ b/pkg/sentry/platform/ring0/lib_amd64.s @@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8 BYTE $0x0f; BYTE $0x30; // WRMSR RET -// jumpToUser changes execution to the user address. -// -// This works by changing the return value to the user version. -TEXT ·jumpToUser(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - NOTQ BX - ANDQ BX, SP // Switch the stack. - ANDQ BX, BP // Switch the frame pointer. - ANDQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// jumpToKernel changes execution to the kernel address space. -// -// This works by changing the return value to the kernel version. -TEXT ·jumpToKernel(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - ORQ BX, SP // Switch the stack. - ORQ BX, BP // Switch the frame pointer. - ORQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// writeCR3 writes the given CR3 value. -// -// The code corresponds to: -// -// mov %rax, %cr3 -// -TEXT ·writeCR3(SB),NOSPLIT,$0-8 - MOVQ cr3+0(FP), AX - BYTE $0x0f; BYTE $0x22; BYTE $0xd8; - RET - -// readCR3 reads the current CR3 value. -// -// The code corresponds to: -// -// mov %cr3, %rax -// -TEXT ·readCR3(SB),NOSPLIT,$0-8 - BYTE $0x0f; BYTE $0x20; BYTE $0xd8; - MOVQ AX, ret+0(FP) - RET - // readCR2 reads the current CR2 value. // // The code corresponds to: diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index a6345010d..2f1abcb0f 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -16,6 +16,15 @@ package ring0 +// storeAppASID writes the application's asid value. +func storeAppASID(asid uintptr) + +// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. +func LocalFlushTlbAll() + +// FlushTlbAll flush all tlb. +func FlushTlbAll() + // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) @@ -38,6 +47,12 @@ func SaveVRegs(*byte) // LoadVRegs loads V0-V31 registers. func LoadVRegs(*byte) +// LoadFloatingPoint loads floating point state. +func LoadFloatingPoint(*byte) + +// SaveFloatingPoint saves floating point state. +func SaveFloatingPoint(*byte) + // GetTLS returns the value of TPIDR_EL0 register. func GetTLS() (value uint64) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index b63e14b41..8aabf7d0e 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,20 @@ #include "funcdata.h" #include "textflag.h" +TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 + DSB $6 // dsb(nshst) + WORD $0xd508871f // __tlbi(vmalle1) + DSB $7 // dsb(nsh) + ISB $15 + RET + +TEXT ·FlushTlbAll(SB),NOSPLIT,$0 + DSB $10 // dsb(ishst) + WORD $0xd508831f // __tlbi(vmalle1is) + DSB $11 // dsb(ish) + ISB $15 + RET + TEXT ·GetTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 MOVD R1, ret+0(FP) @@ -129,3 +143,89 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 ISB $15 RET + +TEXT ·LoadFloatingPoint(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + MOVD 0(R0), R1 + MOVD R1, FPSR + MOVD 8(R0), R1 + MOVD R1, NZCV + + FMOVD 16*1(R0), F0 + FMOVD 16*2(R0), F1 + FMOVD 16*3(R0), F2 + FMOVD 16*4(R0), F3 + FMOVD 16*5(R0), F4 + FMOVD 16*6(R0), F5 + FMOVD 16*7(R0), F6 + FMOVD 16*8(R0), F7 + FMOVD 16*9(R0), F8 + FMOVD 16*10(R0), F9 + FMOVD 16*11(R0), F10 + FMOVD 16*12(R0), F11 + FMOVD 16*13(R0), F12 + FMOVD 16*14(R0), F13 + FMOVD 16*15(R0), F14 + FMOVD 16*16(R0), F15 + FMOVD 16*17(R0), F16 + FMOVD 16*18(R0), F17 + FMOVD 16*19(R0), F18 + FMOVD 16*20(R0), F19 + FMOVD 16*21(R0), F20 + FMOVD 16*22(R0), F21 + FMOVD 16*23(R0), F22 + FMOVD 16*24(R0), F23 + FMOVD 16*25(R0), F24 + FMOVD 16*26(R0), F25 + FMOVD 16*27(R0), F26 + FMOVD 16*28(R0), F27 + FMOVD 16*29(R0), F28 + FMOVD 16*30(R0), F29 + FMOVD 16*31(R0), F30 + FMOVD 16*32(R0), F31 + + RET + +TEXT ·SaveFloatingPoint(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + MOVD FPSR, R1 + MOVD R1, 0(R0) + MOVD NZCV, R1 + MOVD R1, 8(R0) + + FMOVD F0, 16*1(R0) + FMOVD F1, 16*2(R0) + FMOVD F2, 16*3(R0) + FMOVD F3, 16*4(R0) + FMOVD F4, 16*5(R0) + FMOVD F5, 16*6(R0) + FMOVD F6, 16*7(R0) + FMOVD F7, 16*8(R0) + FMOVD F8, 16*9(R0) + FMOVD F9, 16*10(R0) + FMOVD F10, 16*11(R0) + FMOVD F11, 16*12(R0) + FMOVD F12, 16*13(R0) + FMOVD F13, 16*14(R0) + FMOVD F14, 16*15(R0) + FMOVD F15, 16*16(R0) + FMOVD F16, 16*17(R0) + FMOVD F17, 16*18(R0) + FMOVD F18, 16*19(R0) + FMOVD F19, 16*20(R0) + FMOVD F20, 16*21(R0) + FMOVD F21, 16*22(R0) + FMOVD F22, 16*23(R0) + FMOVD F23, 16*24(R0) + FMOVD F24, 16*25(R0) + FMOVD F25, 16*26(R0) + FMOVD F26, 16*27(R0) + FMOVD F27, 16*28(R0) + FMOVD F28, 16*29(R0) + FMOVD F29, 16*30(R0) + FMOVD F30, 16*31(R0) + FMOVD F31, 16*32(R0) + + RET diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go index b8ab120a0..290d94bd6 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/sentry/platform/ring0/offsets_amd64.go @@ -30,14 +30,22 @@ func Emit(w io.Writer) { c := &CPU{} fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + + e := &kernelEntry{} + fmt.Fprintf(w, "\n// CPU entry offsets.\n") + fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_SCRATCH1 0x%02x\n", reflect.ValueOf(&e.scratch1).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0) fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index f3de962f0..1d86b4bcf 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -41,6 +41,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ASID 0x%02x\n", reflect.ValueOf(&c.appASID).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 78510ebed..6409d1d91 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -72,13 +72,14 @@ const ( ) const ( - mtNormal = 0x4 << 2 + mtDevicenGnRE = 0x1 << 2 + mtNormal = 0x4 << 2 ) const ( executeDisable = xn optionMask = 0xfff | 0xfff<<48 - protDefault = accessed | shared | mtNormal + protDefault = accessed | shared ) // MapOpts are x86 options. @@ -184,8 +185,10 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { if opts.User { v |= user + v |= mtNormal } else { v = v &^ user + v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress. } atomic.StoreUintptr((*uintptr)(p), v) } @@ -200,7 +203,7 @@ func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) { // This should never happen. panic("unaligned physical address!") } - v := addr | typeTable | protDefault + v := addr | typeTable | protDefault | mtNormal atomic.StoreUintptr((*uintptr)(p), v) } diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go index 9da0ea685..34fbc1c35 100644 --- a/pkg/sentry/platform/ring0/x86.go +++ b/pkg/sentry/platform/ring0/x86.go @@ -39,7 +39,9 @@ const ( _RFLAGS_AC = 1 << 18 _RFLAGS_NT = 1 << 14 - _RFLAGS_IOPL = 3 << 12 + _RFLAGS_IOPL0 = 1 << 12 + _RFLAGS_IOPL1 = 1 << 13 + _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1 _RFLAGS_DF = 1 << 10 _RFLAGS_IF = 1 << 9 _RFLAGS_STEP = 1 << 8 @@ -67,15 +69,45 @@ const ( KernelFlagsSet = _RFLAGS_RESERVED // UserFlagsSet are always set in userspace. - UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF + // + // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege + // level. The Current Privilege Level (CPL) of the task must be less + // than or equal to the IOPL in order for the task or program to access + // I/O ports. + // + // Here, _RFLAGS_IOPL0 is used only to determine whether the task is + // running in the kernel or userspace mode. In the user mode, the CPL is + // always 3 and it doesn't matter what IOPL is set if it is bellow CPL. + // + // We need to have one bit which will be always different in user and + // kernel modes. And we have to remember that even though we have + // KernelFlagsClear, we still can see some of these flags in the kernel + // mode. This can happen when the goruntime switches on a goroutine + // which has been saved in the host mode. On restore, the popf + // instruction is used to restore flags and this means that all flags + // what the goroutine has in the host mode will be restored in the + // kernel mode. + // + // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set + // it in the user mode. So if this flag is set, the task is running in + // the user mode and if it isn't set, the task is running in the kernel + // mode. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0 // KernelFlagsClear should always be clear in the kernel. KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT // UserFlagsClear are always cleared in userspace. - UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1 ) +// IsKernelFlags returns true if rflags coresponds to the kernel mode. +// +// go:nosplit +func IsKernelFlags(rflags uint64) bool { + return rflags&_RFLAGS_IOPL0 == 0 +} + // Vector is an exception vector. type Vector uintptr @@ -104,7 +136,7 @@ const ( VirtualizationException SecurityException = 0x1e SyscallInt80 = 0x80 - _NR_INTERRUPTS = SyscallInt80 + 1 + _NR_INTERRUPTS = 0x100 ) // System call vectors. diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..a3f775d15 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/marshal", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 8b439a078..70ccf77a7 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -68,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { for _, fd := range fds { file := t.GetFile(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -100,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFiles) Release() { +func (fs *RightsFiles) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -115,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index fd08179be..d9621968c 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -46,7 +46,7 @@ func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { for _, fd := range fds { file := t.GetFileVFS2(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -78,9 +78,9 @@ func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFilesVFS2) Release() { +func (fs *RightsFilesVFS2) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -93,7 +93,7 @@ func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index ff81ea6e6..b6ebe29d6 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -21,6 +21,8 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", @@ -37,6 +39,9 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a92aed2c9..7d3c4a01c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -24,6 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -98,12 +100,12 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc return nil, syserr.FromError(err) } dirent := socket.NewDirent(ctx, socketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) syscall.Close(s.fd) } @@ -267,7 +269,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, syscall.Close(fd) return 0, nil, 0, err } - defer f.DecRef() + defer f.DecRef(t) kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, @@ -279,7 +281,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, syscall.Close(fd) return 0, nil, 0, err } - defer f.DecRef() + defer f.DecRef(t) kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, @@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. @@ -708,6 +711,6 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) - socket.RegisterProviderVFS2(family, &socketProviderVFS2{}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{family}) } } diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 8f192c62f..163af329b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -52,6 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) s := &socketVFS2{ socketOpsCommon: socketOpsCommon{ @@ -71,11 +72,19 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in DenyPWrite: true, UseDentryMetadata: true, }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) return nil, syserr.FromError(err) } return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *socketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) @@ -96,11 +105,6 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// Allocate implements vfs.FileDescriptionImpl.Allocate. -func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { - return syserror.ENODEV -} - // PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index a48082631..faa61160e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -30,6 +30,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -53,11 +56,14 @@ type Stack struct { interfaceAddrs map[int32][]inet.InterfaceAddr routes []inet.Route supportsIPv6 bool + tcpRecovery inet.TCPLossRecovery tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool netDevFile *os.File netSNMPFile *os.File + ipv4Forwarding bool + ipv6Forwarding bool } // NewStack returns an empty Stack containing no configuration. @@ -117,6 +123,13 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } + s.ipv6Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil { + s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + } else { + log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false") + } + return nil } @@ -350,6 +363,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + return s.tcpRecovery, nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserror.EACCES +} + // getLine reads one line from proc file, with specified prefix. // The last argument, withHeader, specifies if it contains line header. func getLine(f *os.File, prefix string, withHeader bool) string { @@ -457,3 +480,21 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber: + return s.ipv4Forwarding + case ipv6.ProtocolNumber: + return s.ipv6Forwarding + default: + log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) + return false + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 721094bbf..8aea0200f 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -6,6 +6,8 @@ go_library( name = "netfilter", srcs = [ "extensions.go", + "ipv4.go", + "ipv6.go", "netfilter.go", "owner_matcher.go", "targets.go", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 0336a32d8..549787955 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -37,7 +39,7 @@ type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. name() string - // marshal converts from an stack.Matcher to an ABI struct. + // marshal converts from a stack.Matcher to an ABI struct. marshal(matcher stack.Matcher) []byte // unmarshal converts from the ABI matcher struct to an @@ -93,3 +95,71 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf } return matchMaker.unmarshal(buf, filter) } + +// targetMaker knows how to (un)marshal a target. Once registered, +// marshalTarget and unmarshalTarget can be used. +type targetMaker interface { + // id uniquely identifies the target. + id() stack.TargetID + + // marshal converts from a stack.Target to an ABI struct. + marshal(target stack.Target) []byte + + // unmarshal converts from the ABI matcher struct to a stack.Target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) +} + +// targetMakers maps the TargetID of supported targets to the targetMaker that +// marshals and unmarshals it. It is immutable after package initialization. +var targetMakers = map[stack.TargetID]targetMaker{} + +func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { + tid := stack.TargetID{ + Name: name, + NetworkProtocol: netProto, + Revision: rev, + } + if _, ok := targetMakers[tid]; !ok { + return 0, false + } + + // Return the highest supported revision unless rev is higher. + for _, other := range targetMakers { + otherID := other.id() + if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { + rev = uint8(otherID.Revision) + } + } + return rev, true +} + +// registerTargetMaker should be called by target extensions to register them +// with the netfilter package. +func registerTargetMaker(tm targetMaker) { + if _, ok := targetMakers[tm.id()]; ok { + panic(fmt.Sprintf("multiple targets registered with name %q.", tm.id())) + } + targetMakers[tm.id()] = tm +} + +func marshalTarget(target stack.Target) []byte { + targetMaker, ok := targetMakers[target.ID()] + if !ok { + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + } + return targetMaker.marshal(target) +} + +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { + tid := stack.TargetID{ + Name: target.Name.String(), + NetworkProtocol: filter.NetworkProtocol(), + Revision: target.Revision, + } + targetMaker, ok := targetMakers[tid] + if !ok { + nflog("unsupported target with name %q", target.Name.String()) + return nil, syserr.ErrInvalidArgument + } + return targetMaker.unmarshal(buf, filter) +} diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go new file mode 100644 index 000000000..b560fae0d --- /dev/null +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -0,0 +1,265 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// emptyIPv4Filter is for comparison with a rule's filters to determine whether +// it is also empty. It is immutable. +var emptyIPv4Filter = stack.IPHeaderFilter{ + Dst: "\x00\x00\x00\x00", + DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", +} + +// convertNetstackToBinary4 converts the iptables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a little data, reading some +// offsets, jumping to those offsets, parsing again, etc. +func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + } + + table, ok := stk.IPTables().GetTable(tablename.String(), false) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + // Setup the info struct. + entries, info := getEntries4(table, tablename) + return entries, info, nil +} + +func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo) { + var info linux.IPTGetinfo + var entries linux.KernelIPTGetEntries + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], info.Name[:]) + info.ValidHooks = table.ValidHooks() + + for ruleIdx, rule := range table.Rules { + nflog("convert to binary: current offset: %d", entries.Size) + + setHooksAndUnderflow(&info, table, entries.Size, ruleIdx) + // Each rule corresponds to an entry. + entry := linux.KernelIPTEntry{ + Entry: linux.IPTEntry{ + IP: linux.IPTIP{ + Protocol: uint16(rule.Filter.Protocol), + }, + NextOffset: linux.SizeOfIPTEntry, + TargetOffset: linux.SizeOfIPTEntry, + }, + } + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } + if rule.Filter.OutputInterfaceInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + nflog("convert to binary: matcher serialized as: %v", serialized) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + + nflog("convert to binary: adding entry: %+v", entry) + + entries.Size += uint32(entry.Entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + info.NumEntries++ + } + + info.Size = entries.Size + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + return entries, info +} + +func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { + nflog("set entries: setting entries in table %q", replace.Name.String()) + + // Convert input into a list of rules and their offsets. + var offset uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + nflog("set entries: processing entry at offset %d", offset) + + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIPTEntry { + nflog("optVal has insufficient size for entry %d", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + var entry linux.IPTEntry + buf := optVal[:linux.SizeOfIPTEntry] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + initialOptValLen := len(optVal) + optVal = optVal[linux.SizeOfIPTEntry:] + + if entry.TargetOffset < linux.SizeOfIPTEntry { + nflog("entry has too-small target offset %d", entry.TargetOffset) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support more IPTIP + // filtering fields. + filter, err := filterFromIPTIP(entry.IP) + if err != nil { + nflog("bad iptip: %v", err) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certain protocols, hooks, tables. + // Get matchers. + matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry + if len(optVal) < int(matchersSize) { + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + matchers, err := parseMatchers(filter, optVal[:matchersSize]) + if err != nil { + nflog("failed to parse matchers: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[matchersSize:] + + // Get the target of the rule. + targetSize := entry.NextOffset - entry.TargetOffset + if len(optVal) < int(targetSize) { + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + + rule := stack.Rule{ + Filter: filter, + Matchers: matchers, + } + + { + target, err := parseTarget(filter, optVal[:targetSize], false /* ipv6 */) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, err + } + rule.Target = target + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, rule) + offsets[offset] = int(entryIdx) + offset += uint32(entry.NextOffset) + + if initialOptValLen-len(optVal) != int(entry.NextOffset) { + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return nil, syserr.ErrInvalidArgument + } + } + return offsets, nil +} + +func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { + if containsUnsupportedFields4(iptip) { + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + } + if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + + return stack.IPHeaderFilter{ + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + // A Protocol value of 0 indicates all protocols match. + CheckProtocol: iptip.Protocol != 0, + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, + }, nil +} + +func containsUnsupportedFields4(iptip linux.IPTIP) bool { + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - Src and SrcMask + // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. + var emptyInterface = [linux.IFNAMSIZ]byte{} + // Disable any supported inverse flags. + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || + iptip.InputInterfaceMask != emptyInterface || + iptip.Flags != 0 || + iptip.InverseFlags&^inverseMask != 0 +} diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go new file mode 100644 index 000000000..4253f7bf4 --- /dev/null +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -0,0 +1,270 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// emptyIPv6Filter is for comparison with a rule's filters to determine whether +// it is also empty. It is immutable. +var emptyIPv6Filter = stack.IPHeaderFilter{ + Dst: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + DstMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", +} + +// convertNetstackToBinary6 converts the ip6tables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a little data, reading some +// offsets, jumping to those offsets, parsing again, etc. +func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo, error) { + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + } + + table, ok := stk.IPTables().GetTable(tablename.String(), true) + if !ok { + return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + // Setup the info struct, which is the same in IPv4 and IPv6. + entries, info := getEntries6(table, tablename) + return entries, info, nil +} + +func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo) { + var info linux.IPTGetinfo + var entries linux.KernelIP6TGetEntries + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], info.Name[:]) + info.ValidHooks = table.ValidHooks() + + for ruleIdx, rule := range table.Rules { + nflog("convert to binary: current offset: %d", entries.Size) + + setHooksAndUnderflow(&info, table, entries.Size, ruleIdx) + // Each rule corresponds to an entry. + entry := linux.KernelIP6TEntry{ + Entry: linux.IP6TEntry{ + IPv6: linux.IP6TIP{ + Protocol: uint16(rule.Filter.Protocol), + }, + NextOffset: linux.SizeOfIP6TEntry, + TargetOffset: linux.SizeOfIP6TEntry, + }, + } + copy(entry.Entry.IPv6.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IPv6.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IPv6.Src[:], rule.Filter.Src) + copy(entry.Entry.IPv6.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IPv6.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IPv6.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_SRCIP + } + if rule.Filter.OutputInterfaceInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_VIA_OUT + } + if rule.Filter.CheckProtocol { + entry.Entry.IPv6.Flags |= linux.IP6T_F_PROTO + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + nflog("convert to binary: matcher serialized as: %v", serialized) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + + nflog("convert to binary: adding entry: %+v", entry) + + entries.Size += uint32(entry.Entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + info.NumEntries++ + } + + info.Size = entries.Size + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + return entries, info +} + +func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { + nflog("set entries: setting entries in table %q", replace.Name.String()) + + // Convert input into a list of rules and their offsets. + var offset uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + nflog("set entries: processing entry at offset %d", offset) + + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIP6TEntry { + nflog("optVal has insufficient size for entry %d", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + var entry linux.IP6TEntry + buf := optVal[:linux.SizeOfIP6TEntry] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + initialOptValLen := len(optVal) + optVal = optVal[linux.SizeOfIP6TEntry:] + + if entry.TargetOffset < linux.SizeOfIP6TEntry { + nflog("entry has too-small target offset %d", entry.TargetOffset) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support more IPTIP + // filtering fields. + filter, err := filterFromIP6TIP(entry.IPv6) + if err != nil { + nflog("bad iptip: %v", err) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certain protocols, hooks, tables. + // Get matchers. + matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry + if len(optVal) < int(matchersSize) { + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + matchers, err := parseMatchers(filter, optVal[:matchersSize]) + if err != nil { + nflog("failed to parse matchers: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[matchersSize:] + + // Get the target of the rule. + targetSize := entry.NextOffset - entry.TargetOffset + if len(optVal) < int(targetSize) { + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + + rule := stack.Rule{ + Filter: filter, + Matchers: matchers, + } + + { + target, err := parseTarget(filter, optVal[:targetSize], true /* ipv6 */) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, err + } + rule.Target = target + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, rule) + offsets[offset] = int(entryIdx) + offset += uint32(entry.NextOffset) + + if initialOptValLen-len(optVal) != int(entry.NextOffset) { + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return nil, syserr.ErrInvalidArgument + } + } + return offsets, nil +} + +func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { + if containsUnsupportedFields6(iptip) { + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + } + if len(iptip.Dst) != header.IPv6AddressSize || len(iptip.DstMask) != header.IPv6AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } + if len(iptip.Src) != header.IPv6AddressSize || len(iptip.SrcMask) != header.IPv6AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + + return stack.IPHeaderFilter{ + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + // In ip6tables a flag controls whether to check the protocol. + CheckProtocol: iptip.Flags&linux.IP6T_F_PROTO != 0, + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IP6T_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, + }, nil +} + +func containsUnsupportedFields6(iptip linux.IP6TIP) bool { + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - Src and SrcMask + // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. + var emptyInterface = [linux.IFNAMSIZ]byte{} + flagMask := uint8(linux.IP6T_F_PROTO) + // Disable any supported inverse flags. + inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || + iptip.InputInterfaceMask != emptyInterface || + iptip.Flags&^flagMask != 0 || + iptip.InverseFlags&^inverseMask != 0 || + iptip.TOS != 0 +} diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f7abe77d3..904a12e38 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,7 +17,6 @@ package netfilter import ( - "bytes" "errors" "fmt" @@ -27,34 +26,15 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) -// errorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const errorTargetName = "ERROR" - -// redirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const redirectTargetName = "REDIRECT" - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. const enableLogging = false -// emptyFilter is for comparison with a rule's filters to determine whether it -// is also empty. It is immutable. -var emptyFilter = stack.IPHeaderFilter{ - Dst: "\x00\x00\x00\x00", - DstMask: "\x00\x00\x00\x00", - Src: "\x00\x00\x00\x00", - SrcMask: "\x00\x00\x00\x00", -} - // nflog logs messages related to the writing and reading of iptables. func nflog(format string, args ...interface{}) { if enableLogging && log.IsLogging(log.Debug) { @@ -63,14 +43,19 @@ func nflog(format string, args ...interface{}) { } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo - if _, err := t.CopyIn(outPtr, &info); err != nil { + if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } - _, info, err := convertNetstackToBinary(stack, info.Name) + var err error + if ipv6 { + _, info, err = convertNetstackToBinary6(stack, info.Name) + } else { + _, info, err = convertNetstackToBinary4(stack, info.Name) + } if err != nil { nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument @@ -80,18 +65,18 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT return info, nil } -// GetEntries returns netstack's iptables rules encoded for the iptables tool. -func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { +// GetEntries4 returns netstack's iptables rules. +func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries - if _, err := t.CopyIn(outPtr, &userEntries); err != nil { + if _, err := userEntries.CopyIn(t, outPtr); err != nil { nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, _, err := convertNetstackToBinary(stack, userEntries.Name) + entries, _, err := convertNetstackToBinary4(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -104,236 +89,53 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -// convertNetstackToBinary converts the iptables as stored in netstack to the -// format expected by the iptables tool. Linux stores each table as a binary -// blob that can only be traversed by parsing a bit, reading some offsets, -// jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { - table, ok := stack.IPTables().GetTable(tablename.String()) - if !ok { - return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) - } - - var entries linux.KernelIPTGetEntries - var info linux.IPTGetinfo - info.ValidHooks = table.ValidHooks() - - // The table name has to fit in the struct. - if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) - } - copy(info.Name[:], tablename[:]) - copy(entries.Name[:], tablename[:]) - - for ruleIdx, rule := range table.Rules { - nflog("convert to binary: current offset: %d", entries.Size) - - // Is this a chain entry point? - for hook, hookRuleIdx := range table.BuiltinChains { - if hookRuleIdx == ruleIdx { - nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - info.HookEntry[hook] = entries.Size - } - } - // Is this a chain underflow point? - for underflow, underflowRuleIdx := range table.Underflows { - if underflowRuleIdx == ruleIdx { - nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - info.Underflow[underflow] = entries.Size - } - } - - // Each rule corresponds to an entry. - entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ - IP: linux.IPTIP{ - Protocol: uint16(rule.Filter.Protocol), - }, - NextOffset: linux.SizeOfIPTEntry, - TargetOffset: linux.SizeOfIPTEntry, - }, - } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) - if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP - } - if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP - } - if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT - } - - for _, matcher := range rule.Matchers { - // Serialize the matcher and add it to the - // entry. - serialized := marshalMatcher(matcher) - nflog("convert to binary: matcher serialized as: %v", serialized) - if len(serialized)%8 != 0 { - panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) - } - entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) - } - - // Serialize and append the target. - serialized := marshalTarget(rule.Target) - if len(serialized)%8 != 0 { - panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) - } - entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - - nflog("convert to binary: adding entry: %+v", entry) - - entries.Size += uint32(entry.NextOffset) - entries.Entrytable = append(entries.Entrytable, entry) - info.NumEntries++ - } - - nflog("convert to binary: finished with an marshalled size of %d", info.Size) - info.Size = entries.Size - return entries, info, nil -} - -func marshalTarget(target stack.Target) []byte { - switch tg := target.(type) { - case stack.AcceptTarget: - return marshalStandardTarget(stack.RuleAccept) - case stack.DropTarget: - return marshalStandardTarget(stack.RuleDrop) - case stack.ErrorTarget: - return marshalErrorTarget(errorTargetName) - case stack.UserChainTarget: - return marshalErrorTarget(tg.Name) - case stack.ReturnTarget: - return marshalStandardTarget(stack.RuleReturn) - case stack.RedirectTarget: - return marshalRedirectTarget(tg) - case JumpTarget: - return marshalJumpTarget(tg) - default: - panic(fmt.Errorf("unknown target of type %T", target)) - } -} - -func marshalStandardTarget(verdict stack.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target") - - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - Verdict: translateFromStandardVerdict(verdict), - } - - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalErrorTarget(errorName string) []byte { - // This is an error target named error - target := linux.XTErrorTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTErrorTarget, - }, - } - copy(target.Name[:], errorName) - copy(target.Target.Name[:], errorTargetName) - - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalRedirectTarget(rt stack.RedirectTarget) []byte { - // This is a redirect target named redirect - target := linux.XTRedirectTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTRedirectTarget, - }, +// GetEntries6 returns netstack's ip6tables rules. +func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) { + // Read in the struct and table name. IPv4 and IPv6 utilize structs + // with the same layout. + var userEntries linux.IPTGetEntries + if _, err := userEntries.CopyIn(t, outPtr); err != nil { + nflog("couldn't copy in entries %q", userEntries.Name) + return linux.KernelIP6TGetEntries{}, syserr.FromError(err) } - copy(target.Target.Name[:], redirectTargetName) - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) - target.NfRange.RangeSize = 1 - if rt.RangeProtoSpecified { - target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + // Convert netstack's iptables rules to something that the iptables + // tool can understand. + entries, _, err := convertNetstackToBinary6(stack, userEntries.Name) + if err != nil { + nflog("couldn't read entries: %v", err) + return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - // Convert port from little endian to big endian. - port := make([]byte, 2) - binary.LittleEndian.PutUint16(port, rt.MinPort) - target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) - binary.LittleEndian.PutUint16(port, rt.MaxPort) - target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalJumpTarget(jt JumpTarget) []byte { - nflog("convert to binary: marshalling jump target") - - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - // Verdict is overloaded by the ABI. When positive, it holds - // the jump offset from the start of the table. - Verdict: int32(jt.Offset), + if binary.Size(entries) > uintptr(outLen) { + nflog("insufficient GetEntries output size: %d", uintptr(outLen)) + return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + return entries, nil } -// translateFromStandardVerdict translates verdicts the same way as the iptables -// tool. -func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { - switch verdict { - case stack.RuleAccept: - return -linux.NF_ACCEPT - 1 - case stack.RuleDrop: - return -linux.NF_DROP - 1 - case stack.RuleReturn: - return linux.NF_RETURN - default: - // TODO(gvisor.dev/issue/170): Support Jump. - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) +// setHooksAndUnderflow checks whether the rule at ruleIdx is a hook entrypoint +// or underflow, in which case it fills in info.HookEntry and info.Underflows. +func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint32, ruleIdx int) { + // Is this a chain entry point? + for hook, hookRuleIdx := range table.BuiltinChains { + if hookRuleIdx == ruleIdx { + nflog("convert to binary: found hook %d at offset %d", hook, offset) + info.HookEntry[hook] = offset + } } -} - -// translateToStandardTarget translates from the value in a -// linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32) (stack.Target, error) { - // TODO(gvisor.dev/issue/170): Support other verdicts. - switch val { - case -linux.NF_ACCEPT - 1: - return stack.AcceptTarget{}, nil - case -linux.NF_DROP - 1: - return stack.DropTarget{}, nil - case -linux.NF_QUEUE - 1: - return nil, errors.New("unsupported iptables verdict QUEUE") - case linux.NF_RETURN: - return stack.ReturnTarget{}, nil - default: - return nil, fmt.Errorf("unknown iptables verdict %d", val) + // Is this a chain underflow point? + for underflow, underflowRuleIdx := range table.Underflows { + if underflowRuleIdx == ruleIdx { + nflog("convert to binary: found underflow %d at offset %d", underflow, offset) + info.Underflow[underflow] = offset + } } } // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { - // Get the basic rules data (struct ipt_replace). - if len(optVal) < linux.SizeOfIPTReplace { - nflog("optVal has insufficient size for replace %d", len(optVal)) - return syserr.ErrInvalidArgument - } +func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] @@ -342,88 +144,24 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument } - nflog("set entries: setting entries in table %q", replace.Name.String()) - - // Convert input into a list of rules and their offsets. - var offset uint32 - // offsets maps rule byte offsets to their position in table.Rules. - offsets := map[uint32]int{} - for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { - nflog("set entries: processing entry at offset %d", offset) - - // Get the struct ipt_entry. - if len(optVal) < linux.SizeOfIPTEntry { - nflog("optVal has insufficient size for entry %d", len(optVal)) - return syserr.ErrInvalidArgument - } - var entry linux.IPTEntry - buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, usermem.ByteOrder, &entry) - initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIPTEntry:] - - if entry.TargetOffset < linux.SizeOfIPTEntry { - nflog("entry has too-small target offset %d", entry.TargetOffset) - return syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/170): We should support more IPTIP - // filtering fields. - filter, err := filterFromIPTIP(entry.IP) - if err != nil { - nflog("bad iptip: %v", err) - return syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/170): Matchers and targets can specify - // that they only work for certain protocols, hooks, tables. - // Get matchers. - matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry - if len(optVal) < int(matchersSize) { - nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) - return syserr.ErrInvalidArgument - } - matchers, err := parseMatchers(filter, optVal[:matchersSize]) - if err != nil { - nflog("failed to parse matchers: %v", err) - return syserr.ErrInvalidArgument - } - optVal = optVal[matchersSize:] - - // Get the target of the rule. - targetSize := entry.NextOffset - entry.TargetOffset - if len(optVal) < int(targetSize) { - nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) - return syserr.ErrInvalidArgument - } - target, err := parseTarget(filter, optVal[:targetSize]) - if err != nil { - nflog("failed to parse target: %v", err) - return syserr.ErrInvalidArgument - } - optVal = optVal[targetSize:] - - table.Rules = append(table.Rules, stack.Rule{ - Filter: filter, - Target: target, - Matchers: matchers, - }) - offsets[offset] = int(entryIdx) - offset += uint32(entry.NextOffset) - - if initialOptValLen-len(optVal) != int(entry.NextOffset) { - nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) - return syserr.ErrInvalidArgument - } + var err *syserr.Error + var offsets map[uint32]int + if ipv6 { + offsets, err = modifyEntries6(stk, optVal, &replace, &table) + } else { + offsets, err = modifyEntries4(stk, optVal, &replace, &table) + } + if err != nil { + return err } // Go through the list of supported hooks for this table and, for each @@ -431,12 +169,14 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<<hook) != 0 { hk := hookFromLinux(hook) + table.BuiltinChains[hk] = stack.HookUnset + table.Underflows[hk] = stack.HookUnset for offset, ruleIdx := range offsets { if offset == replace.HookEntry[hook] { table.BuiltinChains[hk] = ruleIdx } if offset == replace.Underflow[hook] { - if !validUnderflow(table.Rules[ruleIdx]) { + if !validUnderflow(table.Rules[ruleIdx], ipv6) { nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) return syserr.ErrInvalidArgument } @@ -454,10 +194,9 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { } } - // Add the user chains. + // Check the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(stack.UserChainTarget) - if !ok { + if _, ok := rule.Target.(*stack.UserChainTarget); !ok { continue } @@ -473,13 +212,12 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { nflog("user chain's first node must have no matchers") return syserr.ErrInvalidArgument } - table.UserChains[target.Name] = ruleIdx + 1 } // Set each jump to point to the appropriate rule. Right now they hold byte // offsets. for ruleIdx, rule := range table.Rules { - jump, ok := rule.Target.(JumpTarget) + jump, ok := rule.Target.(*JumpTarget) if !ok { continue } @@ -499,8 +237,11 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook == stack.Forward || hook == stack.Postrouting { - if !isUnconditionalAccept(table.Rules[ruleIdx]) { + if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if ruleIdx == stack.HookUnset { + continue + } + if !isUnconditionalAccept(table.Rules[ruleIdx], ipv6) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } @@ -512,9 +253,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - stk.IPTables().ReplaceTable(replace.Name.String(), table) + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) - return nil } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -536,7 +276,6 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, // Check some invariants. if match.MatchSize < linux.SizeOfXTEntryMatch { - return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch) } if len(optVal) < int(match.MatchSize) { @@ -561,186 +300,26 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return matchers, nil } -// parseTarget parses a target from optVal. optVal should contain only the -// target. -func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { - nflog("set entries: parsing target of size %d", len(optVal)) - if len(optVal) < linux.SizeOfXTEntryTarget { - return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) - } - var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &target) - switch target.Name.String() { - case "": - // Standard target. - if len(optVal) != linux.SizeOfXTStandardTarget { - return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) - } - var standardTarget linux.XTStandardTarget - buf = optVal[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - - if standardTarget.Verdict < 0 { - // A Verdict < 0 indicates a non-jump verdict. - return translateToStandardTarget(standardTarget.Verdict) - } - // A verdict >= 0 indicates a jump. - return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil - - case errorTargetName: - // Error target. - if len(optVal) != linux.SizeOfXTErrorTarget { - return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) - } - var errorTarget linux.XTErrorTarget - buf = optVal[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) - - // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named errorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. - // * To mark the start of a user defined chain. These - // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case errorTargetName: - nflog("set entries: error target") - return stack.ErrorTarget{}, nil - default: - // User defined chain. - nflog("set entries: user-defined target %q", name) - return stack.UserChainTarget{Name: name}, nil - } - - case redirectTargetName: - // Redirect target. - if len(optVal) < linux.SizeOfXTRedirectTarget { - return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) - } - - if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - var redirectTarget linux.XTRedirectTarget - buf = optVal[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - - // Copy linux.XTRedirectTarget to stack.RedirectTarget. - var target stack.RedirectTarget - nfRange := redirectTarget.NfRange - - // RangeSize should be 1. - if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - // TODO(gvisor.dev/issue/170): Check if the flags are valid. - // Also check if we need to map ports or IP. - // For now, redirect target only supports destination port change. - // Port range and IP range are not supported yet. - if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - target.RangeProtoSpecified = true - - target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) - target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) - - // TODO(gvisor.dev/issue/170): Port range is not supported yet. - if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - // Convert port from big endian to little endian. - port := make([]byte, 2) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) - target.MinPort = binary.LittleEndian.Uint16(port) - - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) - target.MaxPort = binary.LittleEndian.Uint16(port) - return target, nil - } - - // Unknown target. - return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) -} - -func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { - if containsUnsupportedFields(iptip) { - return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) - } - if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { - return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) - } - if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { - return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) - } - - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - - return stack.IPHeaderFilter{ - Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), - Dst: tcpip.Address(iptip.Dst[:]), - DstMask: tcpip.Address(iptip.DstMask[:]), - DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, - Src: tcpip.Address(iptip.Src[:]), - SrcMask: tcpip.Address(iptip.SrcMask[:]), - SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, - OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, - }, nil -} - -func containsUnsupportedFields(iptip linux.IPTIP) bool { - // The following features are supported: - // - Protocol - // - Dst and DstMask - // - Src and SrcMask - // - The inverse destination IP check flag - // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} - // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags != 0 || - iptip.InverseFlags&^inverseMask != 0 -} - -func validUnderflow(rule stack.Rule) bool { +func validUnderflow(rule stack.Rule, ipv6 bool) bool { if len(rule.Matchers) != 0 { return false } - if rule.Filter != emptyFilter { + if (ipv6 && rule.Filter != emptyIPv6Filter) || (!ipv6 && rule.Filter != emptyIPv4Filter) { return false } switch rule.Target.(type) { - case stack.AcceptTarget, stack.DropTarget: + case *stack.AcceptTarget, *stack.DropTarget: return true default: return false } } -func isUnconditionalAccept(rule stack.Rule) bool { - if !validUnderflow(rule) { +func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { + if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(stack.AcceptTarget) + _, ok := rule.Target.(*stack.AcceptTarget) return ok } @@ -759,3 +338,20 @@ func hookFromLinux(hook int) stack.Hook { } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } + +// TargetRevision returns a linux.XTGetRevision for a given target. It sets +// Revision to the highest supported value, unless the provided revision number +// is larger. +func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) { + // Read in the target name and version. + var rev linux.XTGetRevision + if _, err := rev.CopyIn(t, revPtr); err != nil { + return linux.XTGetRevision{}, syserr.FromError(err) + } + maxSupported, ok := targetRevision(rev.Name.String(), netProto, rev.Revision) + if !ok { + return linux.XTGetRevision{}, syserr.ErrProtocolNotSupported + } + rev.Revision = maxSupported + return rev, nil +} diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index b91ba3ab3..0e14447fe 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,10 +15,359 @@ package netfilter import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) +func init() { + // Standard targets include ACCEPT, DROP, RETURN, and JUMP. + registerTargetMaker(&standardTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&standardTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) + + // Both user chains and actual errors are represented in iptables by + // error targets. + registerTargetMaker(&errorTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&errorTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) + + registerTargetMaker(&redirectTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&nfNATTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) +} + +type standardTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (sm *standardTargetMaker) id() stack.TargetID { + // Standard targets have the empty string as a name and no revisions. + return stack.TargetID{ + NetworkProtocol: sm.NetworkProtocol, + } +} +func (*standardTargetMaker) marshal(target stack.Target) []byte { + // Translate verdicts the same way as the iptables tool. + var verdict int32 + switch tg := target.(type) { + case *stack.AcceptTarget: + verdict = -linux.NF_ACCEPT - 1 + case *stack.DropTarget: + verdict = -linux.NF_DROP - 1 + case *stack.ReturnTarget: + verdict = linux.NF_RETURN + case *JumpTarget: + verdict = int32(tg.Offset) + default: + panic(fmt.Errorf("unknown target of type %T", target)) + } + + // The target's name will be the empty string. + xt := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + Verdict: verdict, + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) != linux.SizeOfXTStandardTarget { + nflog("buf has wrong size for standard target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + var standardTarget linux.XTStandardTarget + buf = buf[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict, filter.NetworkProtocol()) + } + // A verdict >= 0 indicates a jump. + return &JumpTarget{ + Offset: uint32(standardTarget.Verdict), + NetworkProtocol: filter.NetworkProtocol(), + }, nil +} + +type errorTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (em *errorTargetMaker) id() stack.TargetID { + // Error targets have no revision. + return stack.TargetID{ + Name: stack.ErrorTargetName, + NetworkProtocol: em.NetworkProtocol, + } +} + +func (*errorTargetMaker) marshal(target stack.Target) []byte { + var errorName string + switch tg := target.(type) { + case *stack.ErrorTarget: + errorName = stack.ErrorTargetName + case *stack.UserChainTarget: + errorName = tg.Name + default: + panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) + } + + // This is an error target named error + xt := linux.XTErrorTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTErrorTarget, + }, + } + copy(xt.Name[:], errorName) + copy(xt.Target.Name[:], stack.ErrorTargetName) + + ret := make([]byte, 0, linux.SizeOfXTErrorTarget) + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) != linux.SizeOfXTErrorTarget { + nflog("buf has insufficient size for error target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + var errorTarget linux.XTErrorTarget + buf = buf[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named stack.ErrorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch name := errorTarget.Name.String(); name { + case stack.ErrorTargetName: + return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + default: + // User defined chain. + return &stack.UserChainTarget{ + Name: name, + NetworkProtocol: filter.NetworkProtocol(), + }, nil + } +} + +type redirectTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *redirectTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*redirectTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) + // This is a redirect target named redirect + xt := linux.XTRedirectTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTRedirectTarget, + }, + } + copy(xt.Target.Name[:], stack.RedirectTargetName) + + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + xt.NfRange.RangeSize = 1 + xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) + xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) < linux.SizeOfXTRedirectTarget { + nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("redirectTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var redirectTarget linux.XTRedirectTarget + buf = buf[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + + // RangeSize should be 1. + nfRange := redirectTarget.NfRange + if nfRange.RangeSize != 1 { + nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("redirectTargetMaker: invalid range flags %d", nfRange.RangeIPV4.Flags) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP { + nflog("redirectTargetMaker: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + + target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.Port = ntohs(nfRange.RangeIPV4.MinPort) + + return &target, nil +} + +type nfNATTarget struct { + Target linux.XTEntryTarget + Range linux.NFNATRange +} + +const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange + +type nfNATTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *nfNATTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*nfNATTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) + nt := nfNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: nfNATMarhsalledSize, + }, + Range: linux.NFNATRange{ + Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, + }, + } + copy(nt.Target.Name[:], stack.RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.Addr) + copy(nt.Range.MaxAddr[:], rt.Addr) + + nt.Range.MinProto = htons(rt.Port) + nt.Range.MaxProto = nt.Range.MinProto + + ret := make([]byte, 0, nfNATMarhsalledSize) + return binary.Marshal(ret, usermem.ByteOrder, nt) +} + +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if size := nfNATMarhsalledSize; len(buf) < size { + nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("nfNATTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var natRange linux.NFNATRange + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] + binary.Unmarshal(buf, usermem.ByteOrder, &natRange) + + // We don't support port or address ranges. + if natRange.MinAddr != natRange.MaxAddr { + nflog("nfNATTargetMaker: MinAddr and MaxAddr are different") + return nil, syserr.ErrInvalidArgument + } + if natRange.MinProto != natRange.MaxProto { + nflog("nfNATTargetMaker: MinProto and MaxProto are different") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/3549): Check for other flags. + // For now, redirect target only supports destination change. + if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags) + return nil, syserr.ErrInvalidArgument + } + + target := stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Addr: tcpip.Address(natRange.MinAddr[:]), + Port: ntohs(natRange.MinProto), + } + + return &target, nil +} + +// translateToStandardTarget translates from the value in a +// linux.XTStandardTarget to an stack.Verdict. +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { + // TODO(gvisor.dev/issue/170): Support other verdicts. + switch val { + case -linux.NF_ACCEPT - 1: + return &stack.AcceptTarget{NetworkProtocol: netProto}, nil + case -linux.NF_DROP - 1: + return &stack.DropTarget{NetworkProtocol: netProto}, nil + case -linux.NF_QUEUE - 1: + nflog("unsupported iptables verdict QUEUE") + return nil, syserr.ErrInvalidArgument + case linux.NF_RETURN: + return &stack.ReturnTarget{NetworkProtocol: netProto}, nil + default: + nflog("unknown iptables verdict %d", val) + return nil, syserr.ErrInvalidArgument + } +} + +// parseTarget parses a target from optVal. optVal should contain only the +// target. +func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.Target, *syserr.Error) { + nflog("set entries: parsing target of size %d", len(optVal)) + if len(optVal) < linux.SizeOfXTEntryTarget { + nflog("optVal has insufficient size for entry target %d", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + var target linux.XTEntryTarget + buf := optVal[:linux.SizeOfXTEntryTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + + return unmarshalTarget(target, filter, optVal) +} + // JumpTarget implements stack.Target. type JumpTarget struct { // Offset is the byte offset of the rule to jump to. It is used for @@ -27,9 +376,31 @@ type JumpTarget struct { // RuleNum is the rule to jump to. RuleNum int + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (jt *JumpTarget) ID() stack.TargetID { + return stack.TargetID{ + NetworkProtocol: jt.NetworkProtocol, + } } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } + +func ntohs(port uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, port) + return usermem.ByteOrder.Uint16(buf) +} + +func htons(port uint16) uint16 { + buf := make([]byte, 2) + usermem.ByteOrder.PutUint16(buf, port) + return binary.BigEndian.Uint16(buf) +} diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 4f98ee2d5..844acfede 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,21 +97,37 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the stack.Check codepath as matchers are added. + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return false, false - } + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber { + return false, false } + + default: + // We don't know the network protocol. return false, false } - tcpHeader := header.TCP(pkt.TransportHeader) + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { // There's no valid TCP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3f20fc891..63201201c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,23 +94,37 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. - if netHeader.TransportProtocol() != header.UDPProtocolNumber { - return false, false - } + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false } + + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + default: + // We don't know the network protocol. return false, false } - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) if len(udpHeader) < header.UDPMinimumSize { // There's no valid UDP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index d5ca3ac56..1f926aa91 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -16,6 +16,8 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 0d45e5053..31e374833 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -97,7 +97,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int } d := socket.NewDirent(t, netlinkSocketDevice) - defer d.DecRef() + defer d.DecRef(t) return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil } diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index bb205be0d..e8930f031 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -52,6 +52,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..5ddcd4be5 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -21,6 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -138,14 +140,14 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } @@ -162,9 +164,9 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { - s.connection.Release() - s.ep.Close() +func (s *socketOpsCommon) Release(ctx context.Context) { + s.connection.Release(ctx) + s.ep.Close(ctx) if s.bound { s.ports.Release(s.protocol.Protocol(), s.portID) @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -617,7 +621,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -644,7 +648,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys // Add the dump_done_errno payload. m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index dbcd8b49a..c83b23242 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -57,14 +57,14 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } @@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV return fd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ea6ebd0e2..fae3b6783 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -22,6 +22,8 @@ go_library( "//pkg/binary", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e7d2c83d7..87e30d742 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -39,6 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -155,6 +158,9 @@ var Metrics = tcpip.Stats{ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."), + IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."), + IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), @@ -232,7 +238,7 @@ type commonEndpoint interface { // SetSockOpt implements tcpip.Endpoint.SetSockOpt and // transport.Endpoint.SetSockOpt. - SetSockOpt(interface{}) *tcpip.Error + SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and // transport.Endpoint.SetSockOptBool. @@ -244,7 +250,7 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. - GetSockOpt(interface{}) *tcpip.Error + GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and // transport.Endpoint.GetSockOpt. @@ -253,6 +259,9 @@ type commonEndpoint interface { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + + // LastError implements tcpip.Endpoint.LastError. + LastError() *tcpip.Error } // LINT.IfChange @@ -296,8 +305,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -325,7 +335,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue } dirent := socket.NewDirent(t, netstackDevice) - defer dirent.DecRef() + defer dirent.DecRef(t) return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ socketOpsCommon: socketOpsCommon{ Queue: queue, @@ -418,7 +428,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -446,8 +456,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} + + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error - v, cms, err := s.Endpoint.Read(&s.sender) + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -461,8 +484,35 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(ctx context.Context) { + e, ch := waiter.NewChannelEntry(nil) + s.EventRegister(&e, waiter.EventHUp|waiter.EventErr) + defer s.EventUnregister(&e) + s.Endpoint.Close() + + // SO_LINGER option is valid only for TCP. For other socket types + // return after endpoint close. + if family, skType, _ := s.Type(); skType != linux.SOCK_STREAM || (family != linux.AF_INET && family != linux.AF_INET6) { + return + } + + var v tcpip.LingerOption + if err := s.Endpoint.GetSockOpt(&v); err != nil { + return + } + + // The case for zero timeout is handled in tcp endpoint close function. + // Close is blocked until either: + // 1. The endpoint state is not in any of the states: FIN-WAIT1, + // CLOSING and LAST_ACK. + // 2. Timeout is reached. + if v.Enabled && v.Timeout != 0 { + t := kernel.TaskFromContext(ctx) + start := t.Kernel().MonotonicClock().Now() + deadline := start.Add(v.Timeout) + t.BlockWithDeadline(ch, true, deadline) + } } // Read implements fs.FileOperations.Read. @@ -785,7 +835,20 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Issue the bind request to the endpoint. - return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) + err := s.Endpoint.Bind(addr) + if err == tcpip.ErrNoPortAvailable { + // Bind always returns EADDRINUSE irrespective of if the specified port was + // already bound or if an ephemeral port was requested but none were + // available. + // + // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because + // UDP connect returns EAGAIN on ephemeral port exhaustion. + // + // TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion. + err = tcpip.ErrPortInUse + } + + return syserr.TranslateNetstackError(err) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -796,7 +859,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -805,7 +868,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock { + if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { return ep, wq, syserr.TranslateNetstackError(err) } @@ -818,15 +881,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -836,7 +902,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -846,13 +912,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -894,7 +955,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -904,68 +965,33 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil - } - - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_GET_INFO: - if outLen < linux.SizeOfIPTGetinfo { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) - if err != nil { - return nil, err - } - return info, nil - - case linux.IPT_SO_GET_ENTRIES: - if outLen < linux.SizeOfIPTGetEntries { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) - if err != nil { - return nil, err - } - return entries, nil - - } + return &val, nil } - return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen) } // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -974,10 +1000,10 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptTCP(t, ep, name, outLen) case linux.SOL_IPV6: - return getSockOptIPv6(t, ep, name, outLen) + return getSockOptIPv6(t, s, ep, name, outPtr, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen, family) + return getSockOptIP(t, s, ep, name, outPtr, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -998,7 +1024,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1007,11 +1033,14 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.GetSockOpt(tcpip.ErrorOption{}) + err := ep.LastError() if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1019,11 +1048,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1034,7 +1064,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1050,7 +1082,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1066,7 +1099,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1077,7 +1111,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1088,7 +1123,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1096,7 +1133,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1111,7 +1149,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1122,7 +1162,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1133,13 +1175,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + var v tcpip.LingerOption + var linger linux.Linger + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + if v.Enabled { + linger.OnOff = 1 + } + linger.Linger = int32(v.Timeout.Seconds()) + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1147,7 +1202,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1155,7 +1211,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1167,7 +1224,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1178,7 +1236,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1187,7 +1246,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1198,7 +1257,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1209,7 +1270,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1220,7 +1283,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1231,8 +1296,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1243,8 +1308,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1255,8 +1320,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1267,8 +1332,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1279,8 +1344,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1293,12 +1358,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1328,7 +1394,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1339,8 +1407,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + var lingerTimeout primitive.Int32 + if v >= 0 { + lingerTimeout = primitive.Int32(time.Duration(v) / time.Second) + } else { + lingerTimeout = -1 + } + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1352,7 +1425,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1363,8 +1437,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1375,8 +1449,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1384,7 +1458,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1395,7 +1469,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1403,21 +1479,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1428,7 +1507,82 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.IP6T_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet6{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet6), nil + + case linux.IP6T_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true) + if err != nil { + return nil, err + } + return &info, nil + + case linux.IP6T_SO_GET_ENTRIES: + // IPTGetEntries is reused for IPv6. + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return &entries, nil + + case linux.IP6T_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil default: emitUnimplementedEventIPv6(t, name) @@ -1437,7 +1591,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1450,11 +1604,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1466,7 +1621,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1480,7 +1636,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1491,21 +1647,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1516,7 +1677,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1527,7 +1690,82 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil + + case linux.IPT_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false) + if err != nil { + return nil, err + } + return &info, nil + + case linux.IPT_SO_GET_ENTRIES: + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return &entries, nil + + case linux.IPT_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil default: emitUnimplementedEventIP(t, name) @@ -1562,26 +1800,6 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return nil } - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_SET_REPLACE: - if len(optVal) < linux.SizeOfIPTReplace { - return syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return syserr.ErrNoDevice - } - // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal) - - case linux.IPT_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. - return nil - } - } - return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } @@ -1596,21 +1814,26 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptTCP(t, ep, name, optVal) case linux.SOL_IPV6: - return setSockOptIPv6(t, ep, name, optVal) + return setSockOptIPv6(t, s, ep, name, optVal) case linux.SOL_IP: - return setSockOptIP(t, ep, name, optVal) + return setSockOptIP(t, s, ep, name, optVal) + + case linux.SOL_PACKET: + // gVisor doesn't support any SOL_PACKET options just return not + // supported. Returning nil here will result in tcpdump thinking AF_PACKET + // features are supported and proceed to use them and break. + t.Kernel().EmitUnimplementedEvent(t) + return syserr.ErrProtocolNotAvailable case linux.SOL_UDP, linux.SOL_ICMPV6, - linux.SOL_RAW, - linux.SOL_PACKET: + linux.SOL_RAW: t.Kernel().EmitUnimplementedEvent(t) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. @@ -1655,7 +1878,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } name := string(optVal[:n]) if name == "" { - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + v := tcpip.BindToDeviceOption(0) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) } s := t.NetworkContext() if s == nil { @@ -1663,7 +1887,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + v := tcpip.BindToDeviceOption(nicID) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) } } return syserr.ErrUnknownDevice @@ -1729,7 +1954,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + opt := tcpip.OutOfBandInlineOption(v) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1751,14 +1977,21 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return nil + return syserr.TranslateNetstackError( + ep.SetSockOpt(&tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger)})) + + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) default: socket.SetSockOptEmitUnimplementedEvent(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. @@ -1805,7 +2038,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 1 || v > linux.MAX_TCP_KEEPIDLE { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)))) + opt := tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_KEEPINTVL: if len(optVal) < sizeOfInt32 { @@ -1816,7 +2050,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 1 || v > linux.MAX_TCP_KEEPINTVL { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + opt := tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_KEEPCNT: if len(optVal) < sizeOfInt32 { @@ -1838,11 +2073,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 0 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + opt := tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) - if err := ep.SetSockOpt(v); err != nil { + if err := ep.SetSockOpt(&v); err != nil { return syserr.TranslateNetstackError(err) } return nil @@ -1852,8 +2088,9 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + v := int32(usermem.ByteOrder.Uint32(optVal)) + opt := tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_DEFER_ACCEPT: if len(optVal) < sizeOfInt32 { @@ -1863,7 +2100,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 0 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + opt := tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_SYNCNT: if len(optVal) < sizeOfInt32 { @@ -1888,12 +2126,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * emitUnimplementedEventTCP(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. -func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { @@ -1942,12 +2179,32 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + case linux.IP6T_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIP6TReplace { + return syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true) + + case linux.IP6T_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + default: emitUnimplementedEventIPv6(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } var ( @@ -2002,7 +2259,7 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { } // setSockOptIP implements SetSockOpt when level is SOL_IP. -func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2025,7 +2282,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{ NIC: tcpip.NICID(req.InterfaceIndex), // TODO(igudger): Change AddMembership to use the standard // any address representation. @@ -2039,7 +2296,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{ NIC: tcpip.NICID(req.InterfaceIndex), // TODO(igudger): Change DropMembership to use the standard // any address representation. @@ -2053,7 +2310,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), })) @@ -2112,13 +2369,43 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + + case linux.IPT_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false) + + case linux.IPT_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -2147,8 +2434,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // emitUnimplementedEventTCP emits unimplemented event if name is valid. This @@ -2333,7 +2619,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -2439,6 +2725,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2494,6 +2797,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2728,11 +3036,16 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2740,9 +3053,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2761,9 +3072,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2772,52 +3082,49 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2830,9 +3137,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2846,9 +3152,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -2874,7 +3179,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2897,21 +3202,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2920,7 +3232,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2931,32 +3243,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2973,6 +3285,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument @@ -2982,7 +3302,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -3019,9 +3339,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index d65a89316..4c6791fff 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -18,13 +18,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" - "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -56,6 +56,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) s := &SocketVFS2{ socketOpsCommon: socketOpsCommon{ @@ -78,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) @@ -150,14 +158,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs // tcpip.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -167,7 +179,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil { return 0, nil, 0, syserr.FromError(err) @@ -175,13 +187,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { + if peerAddr != nil { // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ @@ -200,7 +208,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,63 +218,28 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil - } - - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_GET_INFO: - if outLen < linux.SizeOfIPTGetinfo { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) - if err != nil { - return nil, err - } - return info, nil - - case linux.IPT_SO_GET_ENTRIES: - if outLen < linux.SizeOfIPTGetEntries { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) - if err != nil { - return nil, err - } - return entries, nil - - } + return &val, nil } - return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen) } // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by @@ -296,26 +269,6 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return nil } - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_SET_REPLACE: - if len(optVal) < linux.SizeOfIPTReplace { - return syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return syserr.ErrNoDevice - } - // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal) - - case linux.IPT_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. - return nil - } - } - return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 548442b96..1028d2a6e 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,6 +15,8 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } @@ -143,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcp.ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -154,17 +166,17 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcp.ReceiveBufferSizeOption{ + rs := tcpip.TCPReceiveBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &rs)).ToError() } // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcp.SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -175,24 +187,40 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcp.SendBufferSizeOption{ + ss := tcpip.TCPSendBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &ss)).ToError() } // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcp.SACKEnabled + var sack tcpip.TCPSACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() + opt := tcpip.TCPSACKEnabled(enabled) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() +} + +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + var recovery tcpip.TCPRecovery + if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + return inet.TCPLossRecovery(recovery), nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + opt := tcpip.TCPRecovery(recovery) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // Statistics implements inet.Stack.Statistics. @@ -384,3 +412,24 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { s.Stack.RestoreCleanupEndpoints(es) } + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + return s.Stack.Forwarding(protocol) + default: + panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol)) + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + s.Stack.SetForwarding(protocol, enable) + default: + panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + } + return nil +} diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fcd7f9d7f..fd31479e5 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -45,8 +46,8 @@ type ControlMessages struct { } // Release releases Unix domain socket credentials and rights. -func (c *ControlMessages) Release() { - c.Unix.Release() +func (c *ControlMessages) Release(ctx context.Context) { + c.Unix.Release(ctx) } // Socket is an interface combining fs.FileOperations and SocketOps, @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cca5e70f1..cc7408698 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,12 +1,37 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "socket_refs", + out = "socket_refs.go", + package = "unix", + prefix = "socketOperations", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "SocketOperations", + }, +) + +go_template_instance( + name = "socket_vfs2_refs", + out = "socket_vfs2_refs.go", + package = "unix", + prefix = "socketVFS2", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "SocketVFS2", + }, +) + go_library( name = "unix", srcs = [ "device.go", "io.go", + "socket_refs.go", + "socket_vfs2_refs.go", "unix.go", "unix_vfs2.go", ], @@ -15,6 +40,8 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/log", + "//pkg/marshal", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index c708b6030..26c3a51b9 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -15,6 +15,17 @@ go_template_instance( }, ) +go_template_instance( + name = "queue_refs", + out = "queue_refs.go", + package = "transport", + prefix = "queue", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "queue", + }, +) + go_library( name = "transport", srcs = [ @@ -22,6 +33,7 @@ go_library( "connectioned_state.go", "connectionless.go", "queue.go", + "queue_refs.go", "transport_message_list.go", "unix.go", ], diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index a1e49cc57..aa4f3c04d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -142,9 +142,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E } q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} - q1.EnableLeakCheck("transport.queue") + q1.EnableLeakCheck() q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - q2.EnableLeakCheck("transport.queue") + q2.EnableLeakCheck() if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} @@ -211,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool { // The socket will be a fresh state after a call to close and may be reused. // That is, close may be used to "unbind" or "disconnect" the socket in error // paths. -func (e *connectionedEndpoint) Close() { +func (e *connectionedEndpoint) Close(ctx context.Context) { e.Lock() var c ConnectedEndpoint var r Receiver @@ -233,7 +233,7 @@ func (e *connectionedEndpoint) Close() { case e.Listening(): close(e.acceptedChan) for n := range e.acceptedChan { - n.Close() + n.Close(ctx) } e.acceptedChan = nil e.path = "" @@ -241,11 +241,11 @@ func (e *connectionedEndpoint) Close() { e.Unlock() if c != nil { c.CloseNotify() - c.Release() + c.Release(ctx) } if r != nil { r.CloseNotify() - r.Release() + r.Release(ctx) } } @@ -300,14 +300,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn } readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} - readQueue.EnableLeakCheck("transport.queue") + readQueue.EnableLeakCheck() ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - writeQueue.EnableLeakCheck("transport.queue") + writeQueue.EnableLeakCheck() if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { @@ -340,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: // Busy; return ECONNREFUSED per spec. - ne.Close() + ne.Close(ctx) e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused @@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { } // Accept accepts a new connection. -func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { +func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() defer e.Unlock() @@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { select { case ne := <-e.acceptedChan: + if peerAddr != nil { + ne.Lock() + c := ne.connected + ne.Unlock() + if c != nil { + addr, err := c.GetLocalAddress() + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + *peerAddr = addr + } + } return ne, nil default: diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 4b06d63ac..f8aacca13 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -42,7 +42,7 @@ var ( func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} - q.EnableLeakCheck("transport.queue") + q.EnableLeakCheck() ep.receiver = &queueReceiver{readQueue: &q} return ep } @@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool { // Close puts the endpoint in a closed state and frees all resources associated // with it. -func (e *connectionlessEndpoint) Close() { +func (e *connectionlessEndpoint) Close(ctx context.Context) { e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) e.connected = nil } @@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() { e.Unlock() r.CloseNotify() - r.Release() + r.Release(ctx) } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. @@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C if err != nil { return 0, syserr.ErrInvalidEndpointState } - defer connected.Release() + defer connected.Release(ctx) e.Lock() - n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) } e.connected = connected e.Unlock() @@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi } // Listen starts listening on the connection. -func (e *connectionlessEndpoint) Listen(int) *syserr.Error { +func (*connectionlessEndpoint) Listen(int) *syserr.Error { return syserr.ErrNotSupported } // Accept accepts a new connection. -func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) { +func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) { return nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index d8f3ad63d..342def28f 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,7 +15,7 @@ package transport import ( - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" @@ -27,7 +27,7 @@ import ( // // +stateify savable type queue struct { - refs.AtomicRefCount + queueRefs ReaderQueue *waiter.Queue WriterQueue *waiter.Queue @@ -57,21 +57,23 @@ func (q *queue) Close() { // Both the read and write queues must be notified after resetting: // q.ReaderQueue.Notify(waiter.EventIn) // q.WriterQueue.Notify(waiter.EventOut) -func (q *queue) Reset() { +func (q *queue) Reset(ctx context.Context) { q.mu.Lock() for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release() + cur.Release(ctx) } q.dataList.Reset() q.used = 0 q.mu.Unlock() } -// DecRef implements RefCounter.DecRef with destructor q.Reset. -func (q *queue) DecRef() { - q.DecRefWithDestructor(q.Reset) - // We don't need to notify after resetting because no one cares about - // this queue after all references have been dropped. +// DecRef implements RefCounter.DecRef. +func (q *queue) DecRef(ctx context.Context) { + q.queueRefs.DecRef(func() { + // We don't need to notify after resetting because no one cares about + // this queue after all references have been dropped. + q.Reset(ctx) + }) } // IsReadable determines if q is currently readable. @@ -111,7 +113,7 @@ func (q *queue) IsWritable() bool { // // If notify is true, ReaderQueue.Notify must be called: // q.ReaderQueue.Notify(waiter.EventIn) -func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { +func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { q.mu.Lock() if q.closed { @@ -124,7 +126,7 @@ func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress } if discardEmpty && l == 0 { q.mu.Unlock() - c.Release() + c.Release(ctx) return 0, false, nil } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2f1b127df..d6fc03520 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -37,7 +37,7 @@ type RightsControlMessage interface { Clone() RightsControlMessage // Release releases any resources owned by the RightsControlMessage. - Release() + Release(ctx context.Context) } // A CredentialsControlMessage is a control message containing Unix credentials. @@ -74,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages { } // Release releases both the credentials and the rights. -func (c *ControlMessages) Release() { +func (c *ControlMessages) Release(ctx context.Context) { if c.Rights != nil { - c.Rights.Release() + c.Rights.Release(ctx) } *c = ControlMessages{} } @@ -90,7 +90,7 @@ type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources // associated with it. - Close() + Close(ctx context.Context) // RecvMsg reads data and a control message from the endpoint. This method // does not block if there is no data pending. @@ -151,7 +151,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *syserr.Error) + // + // peerAddr if not nil will be populated with the address of the connected + // peer on a successful accept. + Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -172,9 +175,8 @@ type Endpoint interface { // connected. GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) - // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option - // types. - SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOpt sets a socket option. + SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error // SetSockOptBool sets a socket option for simple cases when a value has // the int type. @@ -184,9 +186,8 @@ type Endpoint interface { // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error - // GetSockOpt gets a socket option. opt should be a pointer to one of the - // tcpip.*Option types. - GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOpt gets a socket option. + GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error // GetSockOptBool gets a socket option for simple cases when a return // value has the int type. @@ -199,6 +200,9 @@ type Endpoint interface { // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 + + // LastError implements tcpip.Endpoint.LastError. + LastError() *tcpip.Error } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket @@ -252,7 +256,7 @@ type BoundEndpoint interface { // Release releases any resources held by the BoundEndpoint. It must be // called before dropping all references to a BoundEndpoint returned by a // function. - Release() + Release(ctx context.Context) } // message represents a message passed over a Unix domain socket. @@ -281,8 +285,8 @@ func (m *message) Length() int64 { } // Release releases any resources held by the message. -func (m *message) Release() { - m.Control.Release() +func (m *message) Release(ctx context.Context) { + m.Control.Release(ctx) } // Peek returns a copy of the message. @@ -304,7 +308,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -333,7 +337,7 @@ type Receiver interface { // Release releases any resources owned by the Receiver. It should be // called before droping all references to a Receiver. - Release() + Release(ctx context.Context) } // queueReceiver implements Receiver for datagram sockets. @@ -344,7 +348,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -398,8 +402,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 { } // Release implements Receiver.Release. -func (q *queueReceiver) Release() { - q.readQueue.DecRef() +func (q *queueReceiver) Release(ctx context.Context) { + q.readQueue.DecRef(ctx) } // streamQueueReceiver implements Receiver for stream sockets. @@ -456,7 +460,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -502,7 +506,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, var cmTruncated bool if c.Rights != nil && numRights == 0 { - c.Rights.Release() + c.Rights.Release(ctx) c.Rights = nil cmTruncated = true } @@ -557,7 +561,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, // Consume rights. if numRights == 0 { cmTruncated = true - q.control.Rights.Release() + q.control.Rights.Release(ctx) } else { c.Rights = q.control.Rights haveRights = true @@ -582,7 +586,7 @@ type ConnectedEndpoint interface { // // syserr.ErrWouldBlock can be returned along with a partial write if // the caller should block to send the rest of the data. - Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) + Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) // SendNotify notifies the ConnectedEndpoint of a successful Send. This // must not be called while holding any endpoint locks. @@ -616,7 +620,7 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. - Release() + Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. @@ -654,7 +658,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) } // Send implements ConnectedEndpoint.Send. -func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { discardEmpty := false truncate := false if e.endpoint.Type() == linux.SOCK_STREAM { @@ -669,7 +673,7 @@ func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.Fu truncate = true } - return e.writeQueue.Enqueue(data, c, from, discardEmpty, truncate) + return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate) } // SendNotify implements ConnectedEndpoint.SendNotify. @@ -707,8 +711,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 { } // Release implements ConnectedEndpoint.Release. -func (e *connectedEndpoint) Release() { - e.writeQueue.DecRef() +func (e *connectedEndpoint) Release(ctx context.Context) { + e.writeQueue.DecRef(ctx) } // CloseUnread implements ConnectedEndpoint.CloseUnread. @@ -742,6 +746,9 @@ type baseEndpoint struct { // path is not empty if the endpoint has been bound, // or may be used if the endpoint is connected. path string + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // EventRegister implements waiter.Waitable.EventRegister. @@ -798,7 +805,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek) e.Unlock() if err != nil { return 0, 0, ControlMessages{}, false, err @@ -827,7 +834,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return 0, syserr.ErrAlreadyConnected } - n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -837,8 +844,14 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return n, err } -// SetSockOpt sets a socket option. Currently not supported. -func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { +// SetSockOpt sets a socket option. +func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.LingerOption: + e.Lock() + e.linger = *v + e.Unlock() + } return nil } @@ -940,9 +953,12 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: +func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.Lock() + *o = e.linger + e.Unlock() return nil default: @@ -951,6 +967,11 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } +// LastError implements Endpoint.LastError. +func (*baseEndpoint) LastError() *tcpip.Error { + return nil +} + // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { @@ -1001,6 +1022,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } // Release implements BoundEndpoint.Release. -func (*baseEndpoint) Release() { +func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..f80011ce4 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -55,13 +55,14 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + socketOperationsRefs socketOpsCommon } // New creates a new unix socket. func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true}) } @@ -79,34 +80,42 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty stype: stype, }, } - s.EnableLeakCheck("unix.SocketOperations") + s.EnableLeakCheck() return fs.NewFile(ctx, d, flags, &s) } +// DecRef implements RefCounter.DecRef. +func (s *SocketOperations) DecRef(ctx context.Context) { + s.socketOperationsRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implemements fs.FileOperations.Release. +func (s *SocketOperations) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // socketOpsCommon contains the socket operations common to VFS1 and VFS2. // // +stateify savable type socketOpsCommon struct { - refs.AtomicRefCount socket.SendReceiveTimeout ep transport.Endpoint stype linux.SockType -} -// DecRef implements RefCounter.DecRef. -func (s *socketOpsCommon) DecRef() { - s.DecRefWithDestructor(func() { - s.ep.Close() - }) -} - -// Release implemements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { - // Release only decrements a reference on s because s may be referenced in - // the abstract socket namespace. - s.DecRef() + // abstractName and abstractNamespace indicate the name and namespace of the + // socket if it is bound to an abstract socket namespace. Once the socket is + // bound, they cannot be modified. + abstractName string + abstractNamespace *kernel.AbstractSocketNamespace } func (s *socketOpsCommon) isPacket() bool { @@ -184,8 +193,8 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -196,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -205,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -218,22 +227,25 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } } ns := New(t, ep, s.stype) - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -243,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -283,17 +290,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } + s.abstractName = name + s.abstractNamespace = asn } else { // The parent and name. var d *fs.Dirent var name string cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) // Is there no slash at all? if !strings.Contains(p, "/") { @@ -301,7 +312,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { name = p } else { root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Find the last path component, we know that something follows // that final slash, otherwise extractPath() would have failed. lastSlash := strings.LastIndex(p, "/") @@ -317,7 +328,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // No path available. return syserr.ErrNoSuchFile } - defer d.DecRef() + defer d.DecRef(t) name = p[lastSlash+1:] } @@ -331,7 +342,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if err != nil { return syserr.ErrPortInUse } - childDir.DecRef() + childDir.DecRef(t) } return nil @@ -377,9 +388,9 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, FollowFinalSymlink: true, } ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) - root.DecRef() + root.DecRef(t) if relPath { - start.DecRef() + start.DecRef(t) } if e != nil { return nil, syserr.FromError(e) @@ -392,15 +403,15 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, cwd := t.FSContext().WorkingDirectory() remainingTraversals := uint(fs.DefaultTraversalLimit) d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals) - cwd.DecRef() - root.DecRef() + cwd.DecRef(t) + root.DecRef(t) if e != nil { return nil, syserr.FromError(e) } // Extract the endpoint if one is there. ep := d.Inode.BoundEndpoint(path) - d.DecRef() + d.DecRef(t) if ep == nil { // No socket! return nil, syserr.ErrConnectionRefused @@ -414,7 +425,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool if err != nil { return err } - defer ep.Release() + defer ep.Release(t) // Connect the server endpoint. err = s.ep.Connect(t, ep) @@ -472,7 +483,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if err != nil { return 0, err } - defer ep.Release() + defer ep.Release(t) w.To = ep if ep.Passcred() && w.Control.Credentials == nil { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index ff2149250..3345124cc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" @@ -36,12 +37,15 @@ import ( // SocketVFS2 implements socket.SocketVFS2 (and by extension, // vfs.FileDescriptionImpl) for Unix sockets. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl vfs.LockFD + socketVFS2Refs socketOpsCommon } @@ -52,6 +56,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) if err != nil { @@ -87,15 +92,34 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 return vfsfd, nil } +// DecRef implements RefCounter.DecRef. +func (s *SocketVFS2) DecRef(ctx context.Context) { + s.socketVFS2Refs.DecRef(func() { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen) } // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.socketOpsCommon.EventRegister(&e, waiter.EventIn) @@ -104,7 +128,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -117,15 +141,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -135,7 +162,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) @@ -143,13 +170,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ @@ -182,19 +204,23 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } + s.abstractName = name + s.abstractNamespace = asn } else { path := fspath.Parse(p) root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root relPath := !path.Absolute if relPath { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } pop := vfs.PathOperation{ Root: root, @@ -332,7 +358,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) f, err := NewSockfsFile(t, ep, stype) if err != nil { - ep.Close() + ep.Close(t) return nil, err } return f, nil @@ -356,14 +382,14 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (* ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) s1, err := NewSockfsFile(t, ep1, stype) if err != nil { - ep1.Close() - ep2.Close() + ep1.Close(t) + ep2.Close(t) return nil, nil, err } s2, err := NewSockfsFile(t, ep2, stype) if err != nil { - s1.DecRef() - ep2.Close() + s1.DecRef(t) + ep2.Close(t) return nil, nil, err } diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 9eb626b76..245d2c5cf 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -60,8 +60,11 @@ type SaveOpts struct { func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() - defer k.Unpause() - defer log.Infof("Tasks resumed after save.") + k.ReceiveTaskStates() + defer func() { + k.Unpause() + log.Infof("Tasks resumed after save.") + }() w.Stop() defer w.Start() diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 88d5db9fc..a920180d3 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", + "//pkg/marshal/primitive", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go index a6e48b836..ae3b998c8 100644 --- a/pkg/sentry/strace/epoll.go +++ b/pkg/sentry/strace/epoll.go @@ -26,7 +26,7 @@ import ( func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { var e linux.EpollEvent - if _, err := t.CopyIn(eventAddr, &e); err != nil { + if _, err := e.CopyIn(t, eventAddr); err != nil { return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err) } var sb strings.Builder @@ -41,7 +41,7 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui addr := eventsAddr for i := uint64(0); i < numEvents; i++ { var e linux.EpollEvent - if _, err := t.CopyIn(addr, &e); err != nil { + if _, err := e.CopyIn(t, addr); err != nil { fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err) continue } @@ -50,10 +50,10 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui sb.WriteString("...") break } - if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok { - fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr) - continue - } + // Allowing addr to overflow is consistent with Linux, and harmless; if + // this isn't the last iteration of the loop, the next call to CopyIn + // will just fail with EFAULT. + addr, _ = addr.AddLength(uint64(linux.SizeOfEpollEvent)) } sb.WriteString("}") return sb.String() @@ -75,7 +75,7 @@ var epollEventEvents = abi.FlagSet{ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"}, {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"}, {Flag: linux.EPOLLERR, Name: "EPOLLERR"}, - {Flag: linux.EPOLLHUP, Name: "EPULLHUP"}, + {Flag: linux.EPOLLHUP, Name: "EPOLLHUP"}, {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"}, {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"}, {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"}, diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index c0512de89..cc5f70cd4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" @@ -166,7 +167,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) } buf := make([]byte, length) - if _, err := t.CopyIn(addr, &buf); err != nil { + if _, err := t.CopyInBytes(addr, buf); err != nil { return fmt.Sprintf("%#x (error decoding control: %v)", addr, err) } @@ -302,7 +303,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string { var msg slinux.MessageHeader64 - if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil { + if _, err := msg.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err) } s := fmt.Sprintf( @@ -380,9 +381,9 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) { // socklen_t is 32-bits. - var l uint32 - _, err := t.CopyIn(addr, &l) - return l, err + var l primitive.Uint32 + _, err := l.CopyIn(t, addr) + return uint32(l), err } func sockLenPointer(t *kernel.Task, addr usermem.Addr) string { @@ -436,22 +437,22 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string { switch optLen { case 1: - var v uint8 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint8 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } return fmt.Sprintf("%#x {value=%v}", optVal, v) case 2: - var v uint16 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint16 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } return fmt.Sprintf("%#x {value=%v}", optVal, v) case 4: - var v uint32 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint32 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } @@ -521,6 +522,7 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT", linux.IP_PKTOPTIONS: "IP_PKTOPTIONS", linux.IP_MTU: "IP_MTU", + linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST", }, linux.SOL_SOCKET: { linux.SO_ERROR: "SO_ERROR", @@ -631,6 +633,8 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.IPV6_UNICAST_IF: "IPV6_UNICAST_IF", linux.MCAST_MSFILTER: "MCAST_MSFILTER", linux.IPV6_ADDRFORM: "IPV6_ADDRFORM", + linux.IP6T_SO_GET_INFO: "IP6T_SO_GET_INFO", + linux.IP6T_SO_GET_ENTRIES: "IP6T_SO_GET_ENTRIES", }, linux.SOL_NETLINK: { linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR", diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 68ca537c8..396744597 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -17,17 +17,16 @@ package strace import ( - "encoding/binary" "fmt" "strconv" "strings" - "syscall" "time" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/eventchannel" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -91,7 +90,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma } b := make([]byte, size) - amt, err := t.CopyIn(ar.Start, b) + amt, err := t.CopyInBytes(ar.Start, b) if err != nil { iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err) continue @@ -118,7 +117,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st } b := make([]byte, size) - amt, err := t.CopyIn(addr, b) + amt, err := t.CopyInBytes(addr, b) if err != nil { return fmt.Sprintf("%#x (error decoding string: %s)", addr, err) } @@ -147,14 +146,14 @@ func fd(t *kernel.Task, fd int32) string { root := t.FSContext().RootDirectory() if root != nil { - defer root.DecRef() + defer root.DecRef(t) } if fd == linux.AT_FDCWD { wd := t.FSContext().WorkingDirectory() var name string if wd != nil { - defer wd.DecRef() + defer wd.DecRef(t) name, _ = wd.FullName(root) } else { name = "(unknown cwd)" @@ -167,7 +166,7 @@ func fd(t *kernel.Task, fd int32) string { // Cast FD to uint64 to avoid printing negative hex. return fmt.Sprintf("%#x (bad FD)", uint64(fd)) } - defer file.DecRef() + defer file.DecRef(t) name, _ := file.Dirent.FullName(root) return fmt.Sprintf("%#x %s", fd, name) @@ -175,12 +174,12 @@ func fd(t *kernel.Task, fd int32) string { func fdVFS2(t *kernel.Task, fd int32) string { root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) vfsObj := root.Mount().Filesystem().VirtualFilesystem() if fd == linux.AT_FDCWD { wd := t.FSContext().WorkingDirectoryVFS2() - defer wd.DecRef() + defer wd.DecRef(t) name, _ := vfsObj.PathnameWithDeleted(t, root, wd) return fmt.Sprintf("AT_FDCWD %s", name) @@ -191,7 +190,7 @@ func fdVFS2(t *kernel.Task, fd int32) string { // Cast FD to uint64 to avoid printing negative hex. return fmt.Sprintf("%#x (bad FD)", uint64(fd)) } - defer file.DecRef() + defer file.DecRef(t) name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry()) return fmt.Sprintf("%#x %s", fd, name) @@ -199,7 +198,7 @@ func fdVFS2(t *kernel.Task, fd int32) string { func fdpair(t *kernel.Task, addr usermem.Addr) string { var fds [2]int32 - _, err := t.CopyIn(addr, &fds) + _, err := primitive.CopyInt32SliceIn(t, addr, fds[:]) if err != nil { return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err) } @@ -209,7 +208,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string { func uname(t *kernel.Task, addr usermem.Addr) string { var u linux.UtsName - if _, err := t.CopyIn(addr, &u); err != nil { + if _, err := u.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err) } @@ -222,7 +221,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timespec - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err) } @@ -244,7 +243,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timespec - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err) } return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec) @@ -256,7 +255,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timeval - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err) } @@ -268,8 +267,8 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string { return "null" } - var utim syscall.Utimbuf - if _, err := t.CopyIn(addr, &utim); err != nil { + var utim linux.Utime + if _, err := utim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err) } @@ -282,7 +281,7 @@ func stat(t *kernel.Task, addr usermem.Addr) string { } var stat linux.Stat - if _, err := t.CopyIn(addr, &stat); err != nil { + if _, err := stat.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err) } return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec)) @@ -294,7 +293,7 @@ func itimerval(t *kernel.Task, addr usermem.Addr) string { } interval := timeval(t, addr) - value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{}))) + value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } @@ -304,7 +303,7 @@ func itimerspec(t *kernel.Task, addr usermem.Addr) string { } interval := timespec(t, addr) - value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{}))) + value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } @@ -330,7 +329,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string { } var ru linux.Rusage - if _, err := t.CopyIn(addr, &ru); err != nil { + if _, err := ru.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err) } return fmt.Sprintf("%#x %+v", addr, ru) @@ -342,7 +341,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string { } var hdr linux.CapUserHeader - if _, err := t.CopyIn(addr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding header: %s)", addr, err) } @@ -367,7 +366,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { } var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err) } @@ -376,7 +375,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { switch hdr.Version { case linux.LINUX_CAPABILITY_VERSION_1: var data linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := data.CopyIn(t, dataAddr); err != nil { return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err) } p = uint64(data.Permitted) @@ -384,7 +383,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { e = uint64(data.Effective) case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3: var data [2]linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil { return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err) } p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32) diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index d9fb808c0..d23a0068a 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -28,7 +28,7 @@ import ( // CreateEpoll implements the epoll_create(2) linux syscall. func CreateEpoll(t *kernel.Task, closeOnExec bool) (int32, error) { file := epoll.NewEventPoll(t) - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: closeOnExec, @@ -47,14 +47,14 @@ func AddEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, mask if epollfile == nil { return syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Get the target file id. file := t.GetFile(fd) if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) @@ -73,14 +73,14 @@ func UpdateEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, m if epollfile == nil { return syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Get the target file id. file := t.GetFile(fd) if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) @@ -99,14 +99,14 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { if epollfile == nil { return syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Get the target file id. file := t.GetFile(fd) if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) @@ -115,7 +115,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { } // Try to remove the entry. - return e.RemoveEntry(epoll.FileIdentifier{file, fd}) + return e.RemoveEntry(t, epoll.FileIdentifier{file, fd}) } // WaitEpoll implements the epoll_wait(2) linux syscall. @@ -125,7 +125,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve if epollfile == nil { return nil, syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..75752b2e6 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -56,6 +56,7 @@ go_library( "sys_xattr.go", "timespec.go", ], + marshal = True, visibility = ["//:sandbox"], deps = [ "//pkg/abi", @@ -64,6 +65,8 @@ go_library( "//pkg/bpf", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/rand", "//pkg/safemem", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 64de56ac5..dab6207c0 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -36,8 +36,8 @@ var ( // errors, we may consume the error and return only the partial read/write. // // op and f are used only for panics. -func HandleIOErrorVFS2(t *kernel.Task, partialResult bool, err, intr error, op string, f *vfs.FileDescription) error { - known, err := handleIOErrorImpl(t, partialResult, err, intr, op) +func HandleIOErrorVFS2(t *kernel.Task, partialResult bool, ioerr, intr error, op string, f *vfs.FileDescription) error { + known, err := handleIOErrorImpl(t, partialResult, ioerr, intr, op) if err != nil { return err } @@ -46,7 +46,7 @@ func HandleIOErrorVFS2(t *kernel.Task, partialResult bool, err, intr error, op s fs := f.Mount().Filesystem().VirtualFilesystem() root := vfs.RootFromContext(t) name, _ := fs.PathnameWithDeleted(t, root, f.VirtualDentry()) - log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, err, err, op, name) + log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name) partialResultOnce.Do(partialResultMetric.Increment) } return nil @@ -56,15 +56,15 @@ func HandleIOErrorVFS2(t *kernel.Task, partialResult bool, err, intr error, op s // errors, we may consume the error and return only the partial read/write. // // op and f are used only for panics. -func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op string, f *fs.File) error { - known, err := handleIOErrorImpl(t, partialResult, err, intr, op) +func handleIOError(t *kernel.Task, partialResult bool, ioerr, intr error, op string, f *fs.File) error { + known, err := handleIOErrorImpl(t, partialResult, ioerr, intr, op) if err != nil { return err } if !known { // An unknown error is encountered with a partial read/write. name, _ := f.Dirent.FullName(nil /* ignore chroot */) - log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, err, err, op, name, f.FileOperations) + log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations) partialResultOnce.Do(partialResultMetric.Increment) } return nil @@ -147,7 +147,7 @@ func handleIOErrorImpl(t *kernel.Task, partialResult bool, err, intr error, op s } switch err.(type) { - case kernel.SyscallRestartErrno: + case syserror.SyscallRestartErrno: // Identical to the EINTR case. return true, nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ea4f9b1a7..b293669de 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -138,7 +138,7 @@ var AMD64 = &kernel.SyscallTable{ 83: syscalls.Supported("mkdir", Mkdir), 84: syscalls.Supported("rmdir", Rmdir), 85: syscalls.Supported("creat", Creat), - 86: syscalls.Supported("link", Link), + 86: syscalls.PartiallySupported("link", Link, "Limited support with Gofer. Link count and linked files may get out of sync because gVisor is not aware of external hardlinks.", nil), 87: syscalls.Supported("unlink", Unlink), 88: syscalls.Supported("symlink", Symlink), 89: syscalls.Supported("readlink", Readlink), @@ -305,9 +305,9 @@ var AMD64 = &kernel.SyscallTable{ 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), - 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), 257: syscalls.Supported("openat", Openat), 258: syscalls.Supported("mkdirat", Mkdirat), @@ -317,7 +317,7 @@ var AMD64 = &kernel.SyscallTable{ 262: syscalls.Supported("fstatat", Fstatat), 263: syscalls.Supported("unlinkat", Unlinkat), 264: syscalls.Supported("renameat", Renameat), - 265: syscalls.Supported("linkat", Linkat), + 265: syscalls.PartiallySupported("linkat", Linkat, "See link(2).", nil), 266: syscalls.Supported("symlinkat", Symlinkat), 267: syscalls.Supported("readlinkat", Readlinkat), 268: syscalls.Supported("fchmodat", Fchmodat), @@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{ 270: syscalls.Supported("pselect", Pselect), 271: syscalls.Supported("ppoll", Ppoll), 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 273: syscalls.Supported("set_robust_list", SetRobustList), + 274: syscalls.Supported("get_robust_list", GetRobustList), 275: syscalls.Supported("splice", Splice), 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), @@ -346,7 +346,7 @@ var AMD64 = &kernel.SyscallTable{ 291: syscalls.Supported("epoll_create1", EpollCreate1), 292: syscalls.Supported("dup3", Dup3), 293: syscalls.Supported("pipe2", Pipe2), - 294: syscalls.Supported("inotify_init1", InotifyInit1), + 294: syscalls.PartiallySupported("inotify_init1", InotifyInit1, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), 295: syscalls.Supported("preadv", Preadv), 296: syscalls.Supported("pwritev", Pwritev), 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), @@ -454,9 +454,9 @@ var ARM64 = &kernel.SyscallTable{ 23: syscalls.Supported("dup", Dup), 24: syscalls.Supported("dup3", Dup3), 25: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), - 26: syscalls.Supported("inotify_init1", InotifyInit1), - 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 26: syscalls.PartiallySupported("inotify_init1", InotifyInit1, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) @@ -527,8 +527,8 @@ var ARM64 = &kernel.SyscallTable{ 96: syscalls.Supported("set_tid_address", SetTidAddress), 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), - 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 99: syscalls.Supported("set_robust_list", SetRobustList), + 100: syscalls.Supported("get_robust_list", GetRobustList), 101: syscalls.Supported("nanosleep", Nanosleep), 102: syscalls.Supported("getitimer", Getitimer), 103: syscalls.Supported("setitimer", Setitimer), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index ba2557c52..0bf313a13 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -17,6 +17,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -36,7 +37,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // // The context pointer _must_ be zero initially. var idIn uint64 - if _, err := t.CopyIn(idAddr, &idIn); err != nil { + if _, err := primitive.CopyUint64In(t, idAddr, &idIn); err != nil { return 0, nil, err } if idIn != 0 { @@ -49,7 +50,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Copy out the new ID. - if _, err := t.CopyOut(idAddr, &id); err != nil { + if _, err := primitive.CopyUint64Out(t, idAddr, id); err != nil { t.MemoryManager().DestroyAIOContext(t, id) return 0, nil, err } @@ -142,7 +143,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S ev := v.(*linux.IOEvent) // Copy out the result. - if _, err := t.CopyOut(eventsAddr, ev); err != nil { + if _, err := ev.CopyOut(t, eventsAddr); err != nil { if count > 0 { return uintptr(count), nil, nil } @@ -247,7 +248,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu ev.Result = -int64(kernel.ExtractErrno(err, 0)) } - file.DecRef() + file.DecRef(ctx) // Queue the result for delivery. actx.FinishRequest(ev) @@ -257,7 +258,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu // wake up. if eventFile != nil { eventFile.FileOperations.(*eventfd.EventOperations).Signal(1) - eventFile.DecRef() + eventFile.DecRef(ctx) } } } @@ -269,7 +270,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user // File not found. return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Was there an eventFD? Extract it. var eventFile *fs.File @@ -279,7 +280,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user // Bad FD. return syserror.EBADF } - defer eventFile.DecRef() + defer eventFile.DecRef(t) // Check that it is an eventfd. if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok { @@ -338,21 +339,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } for i := int32(0); i < nrEvents; i++ { - // Copy in the address. - cbAddrNative := t.Arch().Native(0) - if _, err := t.CopyIn(addr, cbAddrNative); err != nil { - if i > 0 { - // Some successful. - return uintptr(i), nil, nil + // Copy in the callback address. + var cbAddr usermem.Addr + switch t.Arch().Width() { + case 8: + var cbAddrP primitive.Uint64 + if _, err := cbAddrP.CopyIn(t, addr); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err } - // Nothing done. - return 0, nil, err + cbAddr = usermem.Addr(cbAddrP) + default: + return 0, nil, syserror.ENOSYS } // Copy in this callback. var cb linux.IOCallback - cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) - if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if _, err := cb.CopyIn(t, cbAddr); err != nil { if i > 0 { // Some have been successful. diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go index adf5ea5f2..d3b85e11b 100644 --- a/pkg/sentry/syscalls/linux/sys_capability.go +++ b/pkg/sentry/syscalls/linux/sys_capability.go @@ -45,7 +45,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal dataAddr := args[1].Pointer() var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return 0, nil, err } // hdr.Pid doesn't need to be valid if this capget() is a "version probe" @@ -65,7 +65,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal Permitted: uint32(p), Inheritable: uint32(i), } - _, err = t.CopyOut(dataAddr, &data) + _, err = data.CopyOut(t, dataAddr) return 0, nil, err case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3: @@ -88,12 +88,12 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal Inheritable: uint32(i >> 32), }, } - _, err = t.CopyOut(dataAddr, &data) + _, err = linux.CopyCapUserDataSliceOut(t, dataAddr, data[:]) return 0, nil, err default: hdr.Version = linux.HighestCapabilityVersion - if _, err := t.CopyOut(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } if dataAddr != 0 { @@ -109,7 +109,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal dataAddr := args[1].Pointer() var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return 0, nil, err } switch hdr.Version { @@ -118,7 +118,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EPERM } var data linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := data.CopyIn(t, dataAddr); err != nil { return 0, nil, err } p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities @@ -131,7 +131,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EPERM } var data [2]linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil { return 0, nil, err } p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities @@ -141,7 +141,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal default: hdr.Version = linux.HighestCapabilityVersion - if _, err := t.CopyOut(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go index ed3413ca6..3b4f879e4 100644 --- a/pkg/sentry/syscalls/linux/sys_eventfd.go +++ b/pkg/sentry/syscalls/linux/sys_eventfd.go @@ -37,7 +37,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc event.SetFlags(fs.SettableFileFlags{ NonBlocking: flags&linux.EFD_NONBLOCK != 0, }) - defer event.DecRef() + defer event.DecRef(t) fd, err := t.NewFDFrom(0, event, kernel.FDFlags{ CloseOnExec: flags&linux.EFD_CLOEXEC != 0, diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 2797c6a72..98331eb3c 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" @@ -40,7 +41,7 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent, // Common case: we are accessing a file in the root. root := t.FSContext().RootDirectory() err := fn(root, root, name, linux.MaxSymlinkTraversals) - root.DecRef() + root.DecRef(t) return err } else if dir == "." && dirFD == linux.AT_FDCWD { // Common case: we are accessing a file relative to the current @@ -48,8 +49,8 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent, wd := t.FSContext().WorkingDirectory() root := t.FSContext().RootDirectory() err := fn(root, wd, name, linux.MaxSymlinkTraversals) - wd.DecRef() - root.DecRef() + wd.DecRef(t) + root.DecRef(t) return err } @@ -97,19 +98,19 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro } else { d, err = t.MountNamespace().FindLink(t, root, rel, path, &remainingTraversals) } - root.DecRef() + root.DecRef(t) if wd != nil { - wd.DecRef() + wd.DecRef(t) } if f != nil { - f.DecRef() + f.DecRef(t) } if err != nil { return err } err = fn(root, d, remainingTraversals) - d.DecRef() + d.DecRef(t) return err } @@ -184,9 +185,9 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint file, err := d.Inode.GetFile(t, d, fileFlags) if err != nil { - return syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } - defer file.DecRef() + defer file.DecRef(t) // Success. newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ @@ -242,7 +243,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode if err != nil { return err } - file.DecRef() + file.DecRef(t) return nil case linux.ModeNamedPipe: @@ -332,7 +333,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l if err != nil { break } - defer found.DecRef() + defer found.DecRef(t) // We found something (possibly a symlink). If the // O_EXCL flag was passed, then we can immediately @@ -357,7 +358,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l resolved, err = found.Inode.Getlink(t) if err == nil { // No more resolution necessary. - defer resolved.DecRef() + defer resolved.DecRef(t) break } if err != fs.ErrResolveViaReadlink { @@ -384,7 +385,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l if err != nil { break } - defer newParent.DecRef() + defer newParent.DecRef(t) // Repeat the process with the parent and name of the // symlink target. @@ -414,9 +415,9 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l // Create a new fs.File. newFile, err = found.Inode.GetFile(t, found, fileFlags) if err != nil { - return syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } - defer newFile.DecRef() + defer newFile.DecRef(t) case syserror.ENOENT: // File does not exist. Proceed with creation. @@ -432,7 +433,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l // No luck, bail. return err } - defer newFile.DecRef() + defer newFile.DecRef(t) found = newFile.Dirent default: return err @@ -596,24 +597,24 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Shared flags between file and socket. switch request { case linux.FIONCLEX: - t.FDTable().SetFlags(fd, kernel.FDFlags{ + t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: false, }) return 0, nil, nil case linux.FIOCLEX: - t.FDTable().SetFlags(fd, kernel.FDFlags{ + t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: true, }) return 0, nil, nil case linux.FIONBIO: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.Flags() @@ -627,7 +628,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOASYNC: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.Flags() @@ -641,15 +642,14 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOSETOWN, linux.SIOCSPGRP: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } fSetOwn(t, file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: - who := fGetOwn(t, file) - _, err := t.CopyOut(args[2].Pointer(), &who) + _, err := primitive.CopyInt32Out(t, args[2].Pointer(), fGetOwn(t, file)) return 0, nil, err default: @@ -671,9 +671,9 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal addr := args[0].Pointer() size := args[1].SizeT() cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Get our fullname from the root and preprend unreachable if the root was // unreachable from our current dirent this is the same behavior as on linux. @@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Top it off with a terminator. - _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00")) + _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00")) return uintptr(bytes + 1), nil, err } @@ -722,7 +722,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return err } - t.FSContext().SetRootDirectory(d) + t.FSContext().SetRootDirectory(t, d) return nil }) } @@ -747,7 +747,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return err } - t.FSContext().SetWorkingDirectory(d) + t.FSContext().SetWorkingDirectory(t, d) return nil }) } @@ -760,7 +760,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Is it a directory? if !fs.IsDir(file.Dirent.Inode.StableAttr) { @@ -772,7 +772,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err } - t.FSContext().SetWorkingDirectory(file.Dirent) + t.FSContext().SetWorkingDirectory(t, file.Dirent) return 0, nil, nil } @@ -787,11 +787,11 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that Remove provides a reference on the file that we may use to // flush. It is still active until we drop the final reference below // (and other reference-holding operations complete). - file, _ := t.FDTable().Remove(fd) + file, _ := t.FDTable().Remove(t, fd) if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.Flush(t) return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file) @@ -805,7 +805,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{}) if err != nil { @@ -826,7 +826,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if oldFile == nil { return 0, nil, syserror.EBADF } - defer oldFile.DecRef() + defer oldFile.DecRef(t) return uintptr(newfd), nil, nil } @@ -850,7 +850,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if oldFile == nil { return 0, nil, syserror.EBADF } - defer oldFile.DecRef() + defer oldFile.DecRef(t) err := t.NewFDAt(newfd, oldFile, kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0}) if err != nil { @@ -925,7 +925,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) switch cmd { case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: @@ -941,7 +941,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - err := t.FDTable().SetFlags(fd, kernel.FDFlags{ + err := t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) return 0, nil, err @@ -962,7 +962,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock - if _, err := t.CopyIn(flockAddr, &flock); err != nil { + if _, err := flock.CopyIn(t, flockAddr); err != nil { return 0, nil, err } @@ -1052,12 +1052,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) - _, err := t.CopyOut(addr, &owner) + _, err := owner.CopyOut(t, addr) return 0, nil, err case linux.F_SETOWN_EX: addr := args[2].Pointer() var owner linux.FOwnerEx - n, err := t.CopyIn(addr, &owner) + _, err := owner.CopyIn(t, addr) if err != nil { return 0, nil, err } @@ -1069,21 +1069,21 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.ESRCH } a.SetOwnerTask(t, task) - return uintptr(n), nil, nil + return 0, nil, nil case linux.F_OWNER_PID: tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID)) if tg == nil { return 0, nil, syserror.ESRCH } a.SetOwnerThreadGroup(t, tg) - return uintptr(n), nil, nil + return 0, nil, nil case linux.F_OWNER_PGRP: pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID)) if pg == nil { return 0, nil, syserror.ESRCH } a.SetOwnerProcessGroup(t, pg) - return uintptr(n), nil, nil + return 0, nil, nil default: return 0, nil, syserror.EINVAL } @@ -1132,7 +1132,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // If the FD refers to a pipe or FIFO, return error. if fs.IsPipe(file.Dirent.Inode.StableAttr) { @@ -1154,6 +1154,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.ThenChange(vfs2/fd.go) + +// LINT.IfChange + func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1171,7 +1175,7 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode switch err { case nil: // The directory existed. - defer f.DecRef() + defer f.DecRef(t) return syserror.EEXIST case syserror.EACCES: // Permission denied while walking to the directory. @@ -1349,7 +1353,7 @@ func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32 if target == nil { return syserror.EBADF } - defer target.DecRef() + defer target.DecRef(t) if err := mayLinkAt(t, target.Dirent.Inode); err != nil { return err } @@ -1602,7 +1606,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Reject truncation if the file flags do not permit this operation. // This is different from truncate(2) above. @@ -1730,7 +1734,7 @@ func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bo if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return chown(t, file.Dirent, uid, gid) } @@ -1768,7 +1772,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, chown(t, file.Dirent, uid, gid) } @@ -1833,7 +1837,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, chmod(t, file.Dirent, mode) } @@ -1893,10 +1897,10 @@ func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, reso if f == nil { return syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) return setTimestamp(root, f.Dirent, linux.MaxSymlinkTraversals) } @@ -1918,7 +1922,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times linux.Utime - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := times.CopyIn(t, timesAddr); err != nil { return 0, nil, err } ts = fs.TimeSpec{ @@ -1938,7 +1942,7 @@ func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } ts = fs.TimeSpec{ @@ -1966,7 +1970,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timespec - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) { @@ -2000,7 +2004,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } if times[0].Usec >= 1e6 || times[0].Usec < 0 || @@ -2088,7 +2092,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) if offset < 0 || length <= 0 { return 0, nil, syserror.EINVAL @@ -2141,7 +2145,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // flock(2): EBADF fd is not an open file descriptor. return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) nonblocking := operation&linux.LOCK_NB != 0 operation &^= linux.LOCK_NB @@ -2224,8 +2228,8 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } - defer dirent.DecRef() - defer file.DecRef() + defer dirent.DecRef(t) + defer file.DecRef(t) newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: cloExec, diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b68261f72..f39ce0639 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -73,8 +73,8 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo err = t.BlockWithDeadline(w.C, true, ktime.FromTimespec(ts)) } - t.Futex().WaitComplete(w) - return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + t.Futex().WaitComplete(w, t) + return 0, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // futexWaitDuration performs a FUTEX_WAIT, blocking until the wait is @@ -95,7 +95,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add } remaining, err := t.BlockWithTimeout(w.C, !forever, duration) - t.Futex().WaitComplete(w) + t.Futex().WaitComplete(w, t) if err == nil { return 0, nil } @@ -110,7 +110,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add // The wait duration was absolute, restart with the original arguments. if forever { - return 0, kernel.ERESTARTSYS + return 0, syserror.ERESTARTSYS } // The wait duration was relative, restart with the remaining duration. @@ -121,7 +121,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add val: val, mask: mask, }) - return 0, kernel.ERESTART_RESTARTBLOCK + return 0, syserror.ERESTART_RESTARTBLOCK } func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.Addr, private bool) error { @@ -148,8 +148,8 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A timer.Destroy() } - t.Futex().WaitComplete(w) - return syserror.ConvertIntr(err, kernel.ERESTARTSYS) + t.Futex().WaitComplete(w, t) + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } func tryLockPI(t *kernel.Task, addr usermem.Addr, private bool) error { @@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall switch cmd { case linux.FUTEX_WAIT: // WAIT uses a relative timeout. - mask = ^uint32(0) + mask = linux.FUTEX_BITSET_MATCH_ANY var timeoutDur time.Duration if !forever { timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond @@ -286,3 +286,53 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.ENOSYS } } + +// SetRobustList implements linux syscall set_robust_list(2). +func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + head := args[0].Pointer() + length := args[1].SizeT() + + if length != uint(linux.SizeOfRobustListHead) { + return 0, nil, syserror.EINVAL + } + t.SetRobustList(head) + return 0, nil, nil +} + +// GetRobustList implements linux syscall get_robust_list(2). +func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + tid := args[0].Int() + headAddr := args[1].Pointer() + sizeAddr := args[2].Pointer() + + if tid < 0 { + return 0, nil, syserror.EINVAL + } + + ot := t + if tid != 0 { + if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil { + return 0, nil, syserror.ESRCH + } + } + + // Copy out head pointer. + head := t.Arch().Native(uintptr(ot.GetRobustList())) + if _, err := head.CopyOut(t, headAddr); err != nil { + return 0, nil, err + } + + // Copy out size, which is a constant. Note that while size isn't + // an address, it is defined as the arch-dependent size_t, so it + // needs to be converted to a native-sized int. + size := t.Arch().Native(uintptr(linux.SizeOfRobustListHead)) + if _, err := size.CopyOut(t, sizeAddr); err != nil { + return 0, nil, err + } + + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index b126fecc0..b25f7d881 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -19,7 +19,6 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -68,7 +67,7 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir if dir == nil { return 0, syserror.EBADF } - defer dir.DecRef() + defer dir.DecRef(t) w := &usermem.IOReadWriter{ Ctx: t, @@ -82,7 +81,7 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir ds := newDirentSerializer(f, w, t.Arch(), size) rerr := dir.Readdir(t, ds) - switch err := handleIOError(t, ds.Written() > 0, rerr, kernel.ERESTARTSYS, "getdents", dir); err { + switch err := handleIOError(t, ds.Written() > 0, rerr, syserror.ERESTARTSYS, "getdents", dir); err { case nil: dir.Dirent.InotifyEvent(linux.IN_ACCESS, 0) return uintptr(ds.Written()), nil @@ -93,19 +92,23 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir } } -// oldDirentHdr is a fixed sized header matching the fixed size -// fields found in the old linux dirent struct. +// oldDirentHdr is a fixed sized header matching the fixed size fields found in +// the old linux dirent struct. +// +// +marshal type oldDirentHdr struct { Ino uint64 Off uint64 - Reclen uint16 + Reclen uint16 `marshal:"unaligned"` // Struct ends mid-word. } -// direntHdr is a fixed sized header matching the fixed size -// fields found in the new linux dirent struct. +// direntHdr is a fixed sized header matching the fixed size fields found in the +// new linux dirent struct. +// +// +marshal type direntHdr struct { OldHdr oldDirentHdr - Typ uint8 + Typ uint8 `marshal:"unaligned"` // Struct ends mid-word. } // dirent contains the data pointed to by a new linux dirent struct. @@ -134,20 +137,20 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent // the old linux dirent format. func smallestDirent(a arch.Context) uint { d := dirent{} - return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1 + return uint(d.Hdr.OldHdr.SizeBytes()) + a.Width() + 1 } // smallestDirent64 returns the size of the smallest possible dirent using // the new linux dirent format. func smallestDirent64(a arch.Context) uint { d := dirent{} - return uint(binary.Size(d.Hdr)) + a.Width() + return uint(d.Hdr.SizeBytes()) + a.Width() } // padRec pads the name field until the rec length is a multiple of the width, // which must be a power of 2. It returns the padded rec length. func (d *dirent) padRec(width int) uint16 { - a := int(binary.Size(d.Hdr)) + len(d.Name) + a := d.Hdr.SizeBytes() + len(d.Name) r := (a + width) &^ (width - 1) padding := r - a d.Name = append(d.Name, make([]byte, padding)...) @@ -157,7 +160,7 @@ func (d *dirent) padRec(width int) uint16 { // Serialize64 serializes a Dirent struct to a byte slice, keeping the new // linux dirent format. Returns the number of bytes serialized or an error. func (d *dirent) Serialize64(w io.Writer) (int, error) { - n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr)) + n1, err := d.Hdr.WriteTo(w) if err != nil { return 0, err } @@ -165,14 +168,14 @@ func (d *dirent) Serialize64(w io.Writer) (int, error) { if err != nil { return 0, err } - return n1 + n2, nil + return int(n1) + n2, nil } // Serialize serializes a Dirent struct to a byte slice, using the old linux // dirent format. // Returns the number of bytes serialized or an error. func (d *dirent) Serialize(w io.Writer) (int, error) { - n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr)) + n1, err := d.Hdr.OldHdr.WriteTo(w) if err != nil { return 0, err } @@ -184,7 +187,7 @@ func (d *dirent) Serialize(w io.Writer) (int, error) { if err != nil { return 0, err } - return n1 + n2 + n3, nil + return int(n1) + n2 + n3, nil } // direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go index 715ac45e6..a29d307e5 100644 --- a/pkg/sentry/syscalls/linux/sys_identity.go +++ b/pkg/sentry/syscalls/linux/sys_identity.go @@ -49,13 +49,13 @@ func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ruid := c.RealKUID.In(c.UserNamespace).OrOverflow() euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow() suid := c.SavedKUID.In(c.UserNamespace).OrOverflow() - if _, err := t.CopyOut(ruidAddr, ruid); err != nil { + if _, err := ruid.CopyOut(t, ruidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(euidAddr, euid); err != nil { + if _, err := euid.CopyOut(t, euidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(suidAddr, suid); err != nil { + if _, err := suid.CopyOut(t, suidAddr); err != nil { return 0, nil, err } return 0, nil, nil @@ -84,13 +84,13 @@ func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys rgid := c.RealKGID.In(c.UserNamespace).OrOverflow() egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow() sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow() - if _, err := t.CopyOut(rgidAddr, rgid); err != nil { + if _, err := rgid.CopyOut(t, rgidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(egidAddr, egid); err != nil { + if _, err := egid.CopyOut(t, egidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(sgidAddr, sgid); err != nil { + if _, err := sgid.CopyOut(t, sgidAddr); err != nil { return 0, nil, err } return 0, nil, nil @@ -157,7 +157,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys for i, kgid := range kgids { gids[i] = kgid.In(t.UserNamespace()).OrOverflow() } - if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil { + if _, err := auth.CopyGIDSliceOut(t, args[1].Pointer(), gids); err != nil { return 0, nil, err } return uintptr(len(gids)), nil, nil @@ -173,7 +173,7 @@ func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, t.SetExtraGIDs(nil) } gids := make([]auth.GID, size) - if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil { + if _, err := auth.CopyGIDSliceIn(t, args[1].Pointer(), gids); err != nil { return 0, nil, err } return 0, nil, t.SetExtraGIDs(gids) diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go index b2c7b3444..cf47bb9dd 100644 --- a/pkg/sentry/syscalls/linux/sys_inotify.go +++ b/pkg/sentry/syscalls/linux/sys_inotify.go @@ -40,7 +40,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. NonBlocking: flags&linux.IN_NONBLOCK != 0, } n := fs.NewFile(t, dirent, fileFlags, fs.NewInotify(t)) - defer n.DecRef() + defer n.DecRef(t) fd, err := t.NewFDFrom(0, n, kernel.FDFlags{ CloseOnExec: flags&linux.IN_CLOEXEC != 0, @@ -71,7 +71,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) { ino, ok := file.FileOperations.(*fs.Inotify) if !ok { // Not an inotify fd. - file.DecRef() + file.DecRef(t) return nil, nil, syserror.EINVAL } @@ -98,7 +98,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -128,6 +128,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if err != nil { return 0, nil, err } - defer file.DecRef() - return 0, nil, ino.RmWatch(wd) + defer file.DecRef(t) + return 0, nil, ino.RmWatch(t, wd) } diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index 3f7691eae..0046347cb 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -33,7 +33,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) var sw fs.SeekWhence switch whence { @@ -48,7 +48,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } offset, serr := file.Seek(t, sw, offset) - err := handleIOError(t, false /* partialResult */, serr, kernel.ERESTARTSYS, "lseek", file) + err := handleIOError(t, false /* partialResult */, serr, syserror.ERESTARTSYS, "lseek", file) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 91694d374..cd8dfdfa4 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -75,7 +75,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } defer func() { if opts.MappingIdentity != nil { - opts.MappingIdentity.DecRef() + opts.MappingIdentity.DecRef(t) } }() @@ -85,7 +85,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) flags := file.Flags() // mmap unconditionally requires that the FD is readable. @@ -100,6 +100,15 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if err := file.ConfigureMMap(t, &opts); err != nil { return 0, nil, err } + } else if shared { + // Back shared anonymous mappings with a special mappable. + opts.Offset = 0 + m, err := mm.NewSharedAnonMappable(opts.Length, t.Kernel()) + if err != nil { + return 0, nil, err + } + opts.MappingIdentity = m // transfers ownership of m to opts + opts.Mappable = m } rv, err := t.MemoryManager().MMap(t, opts) @@ -239,7 +248,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, syserror.ENOMEM } resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize)) - _, err := t.CopyOut(vec, resident) + _, err := t.CopyOutBytes(vec, resident) return 0, nil, err } @@ -267,7 +276,7 @@ func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall }) // MSync calls fsync, the same interrupt conversion rules apply, see // mm/msync.c, fsync POSIX.1-2008. - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // Mlock implements linux syscall mlock(2). diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go index eb5ff48f5..bd0633564 100644 --- a/pkg/sentry/syscalls/linux/sys_mount.go +++ b/pkg/sentry/syscalls/linux/sys_mount.go @@ -115,7 +115,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall }); err != nil { // Something went wrong. Drop our ref on rootInode before // returning the error. - rootInode.DecRef() + rootInode.DecRef(t) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 43c510930..849a47476 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -34,10 +35,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize) r.SetFlags(linuxToFlags(flags).Settable()) - defer r.DecRef() + defer r.DecRef(t) w.SetFlags(linuxToFlags(flags).Settable()) - defer w.DecRef() + defer w.DecRef(t) fds, err := t.NewFDs(0, []*fs.File{r, w}, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -46,10 +47,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { return 0, err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if file, _ := t.FDTable().Remove(fd); file != nil { - file.DecRef() + if file, _ := t.FDTable().Remove(t, fd); file != nil { + file.DecRef(t) } } return 0, err diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index f0198141c..254f4c9f9 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -70,7 +70,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } if ch == nil { - defer file.DecRef() + defer file.DecRef(t) } else { state.file = file state.waiter, _ = waiter.NewChannelEntry(ch) @@ -82,11 +82,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } // releaseState releases all the pollState in "state". -func releaseState(state []pollState) { +func releaseState(t *kernel.Task, state []pollState) { for i := range state { if state[i].file != nil { state[i].file.EventUnregister(&state[i].waiter) - state[i].file.DecRef() + state[i].file.DecRef(t) } } } @@ -107,7 +107,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // result, we stop registering for events but still go through all files // to get their ready masks. state := make([]pollState, len(pfd)) - defer releaseState(state) + defer releaseState(t, state) n := uintptr(0) for i := range pfd { initReadiness(t, &pfd[i], &state[i], ch) @@ -162,7 +162,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD pfd := make([]linux.PollFD, nfds) if nfds > 0 { - if _, err := t.CopyIn(addr, &pfd); err != nil { + if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil { return nil, err } } @@ -189,7 +189,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. if nfds > 0 && err == nil { - if _, err := t.CopyOut(addr, pfd); err != nil { + if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil { return remainingTimeout, 0, err } } @@ -202,7 +202,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy set := make([]byte, nBytes) if addr != 0 { - if _, err := t.CopyIn(addr, &set); err != nil { + if _, err := t.CopyInBytes(addr, set); err != nil { return nil, err } // If we only use part of the last byte, mask out the extraneous bits. @@ -266,7 +266,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add if file == nil { return 0, syserror.EBADF } - file.DecRef() + file.DecRef(t) var mask int16 if (rV & m) != 0 { @@ -329,19 +329,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Copy updated vectors back. if readFDs != 0 { - if _, err := t.CopyOut(readFDs, r); err != nil { + if _, err := t.CopyOutBytes(readFDs, r); err != nil { return 0, err } } if writeFDs != 0 { - if _, err := t.CopyOut(writeFDs, w); err != nil { + if _, err := t.CopyOutBytes(writeFDs, w); err != nil { return 0, err } } if exceptFDs != 0 { - if _, err := t.CopyOut(exceptFDs, e); err != nil { + if _, err := t.CopyOutBytes(exceptFDs, e); err != nil { return 0, err } } @@ -410,7 +410,7 @@ func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration nfds: nfds, timeout: remainingTimeout, }) - return 0, kernel.ERESTART_RESTARTBLOCK + return 0, syserror.ERESTART_RESTARTBLOCK } return n, err } @@ -464,7 +464,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } @@ -494,7 +494,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } @@ -539,7 +539,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index f92bf8096..a892d2c62 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -43,7 +44,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, nil case linux.PR_GET_PDEATHSIG: - _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal())) + _, err := primitive.CopyInt32Out(t, args[1].Pointer(), int32(t.ParentDeathSignal())) return 0, nil, err case linux.PR_GET_DUMPABLE: @@ -110,7 +111,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall buf[len] = 0 len++ } - _, err := t.CopyOut(addr, buf[:len]) + _, err := t.CopyOutBytes(addr, buf[:len]) if err != nil { return 0, nil, err } @@ -128,7 +129,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // They trying to set exe to a non-file? if !fs.IsFile(file.Dirent.Inode.StableAttr) { @@ -136,7 +137,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } // Set the underlying executable. - t.MemoryManager().SetExecutable(fsbridge.NewFSFile(file)) + t.MemoryManager().SetExecutable(t, fsbridge.NewFSFile(file)) case linux.PR_SET_MM_AUXV, linux.PR_SET_MM_START_CODE, diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 071b4bacc..f655d3db1 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -48,7 +48,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.Flags().Read { @@ -71,7 +71,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "read", file) } // Readahead implements readahead(2). @@ -84,7 +84,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.Flags().Read { @@ -118,7 +118,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -151,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file) } // Readv implements linux syscall readv(2). @@ -164,7 +164,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.Flags().Read { @@ -181,7 +181,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "readv", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "readv", file) } // Preadv implements linux syscall preadv(2). @@ -195,7 +195,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -222,7 +222,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file) } // Preadv2 implements linux syscall preadv2(2). @@ -244,7 +244,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { @@ -280,12 +280,12 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if offset == -1 { n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) } n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) } func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index d5d5b6959..309c183a3 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -26,17 +27,13 @@ import ( // rlimit describes an implementation of 'struct rlimit', which may vary from // system-to-system. type rlimit interface { + marshal.Marshallable + // toLimit converts an rlimit to a limits.Limit. toLimit() *limits.Limit // fromLimit converts a limits.Limit to an rlimit. fromLimit(lim limits.Limit) - - // copyIn copies an rlimit from the untrusted app to the kernel. - copyIn(t *kernel.Task, addr usermem.Addr) error - - // copyOut copies an rlimit from the kernel to the untrusted app. - copyOut(t *kernel.Task, addr usermem.Addr) error } // newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system. @@ -50,6 +47,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) { } } +// +marshal type rlimit64 struct { Cur uint64 Max uint64 @@ -70,12 +68,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) { } func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error { - _, err := t.CopyIn(addr, r) + _, err := r.CopyIn(t, addr) return err } func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error { - _, err := t.CopyOut(addr, *r) + _, err := r.CopyOut(t, addr) return err } @@ -140,7 +138,8 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } rlim.fromLimit(lim) - return 0, nil, rlim.copyOut(t, addr) + _, err = rlim.CopyOut(t, addr) + return 0, nil, err } // Setrlimit implements linux syscall setrlimit(2). @@ -155,7 +154,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if err != nil { return 0, nil, err } - if err := rlim.copyIn(t, addr); err != nil { + if _, err := rlim.CopyIn(t, addr); err != nil { return 0, nil, syserror.EFAULT } _, err = prlimit64(t, resource, rlim.toLimit()) diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go index 1674c7445..ac5c98a54 100644 --- a/pkg/sentry/syscalls/linux/sys_rusage.go +++ b/pkg/sentry/syscalls/linux/sys_rusage.go @@ -80,7 +80,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } ru := getrusage(t, which) - _, err := t.CopyOut(addr, &ru) + _, err := ru.CopyOut(t, addr) return 0, nil, err } @@ -104,7 +104,7 @@ func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall CUTime: linux.ClockTFromDuration(cs2.UserTime), CSTime: linux.ClockTFromDuration(cs2.SysTime), } - if _, err := t.CopyOut(addr, &r); err != nil { + if _, err := r.CopyOut(t, addr); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go index 99f6993f5..bfcf44b6f 100644 --- a/pkg/sentry/syscalls/linux/sys_sched.go +++ b/pkg/sentry/syscalls/linux/sys_sched.go @@ -27,8 +27,10 @@ const ( ) // SchedParam replicates struct sched_param in sched.h. +// +// +marshal type SchedParam struct { - schedPriority int64 + schedPriority int32 } // SchedGetparam implements linux syscall sched_getparam(2). @@ -45,7 +47,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.ESRCH } r := SchedParam{schedPriority: onlyPriority} - if _, err := t.CopyOut(param, r); err != nil { + if _, err := r.CopyOut(t, param); err != nil { return 0, nil, err } @@ -79,7 +81,7 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke return 0, nil, syserror.ESRCH } var r SchedParam - if _, err := t.CopyIn(param, &r); err != nil { + if _, err := r.CopyIn(t, param); err != nil { return 0, nil, syserror.EINVAL } if r.schedPriority != onlyPriority { diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go index 5b7a66f4d..4fdb4463c 100644 --- a/pkg/sentry/syscalls/linux/sys_seccomp.go +++ b/pkg/sentry/syscalls/linux/sys_seccomp.go @@ -24,6 +24,8 @@ import ( ) // userSockFprog is equivalent to Linux's struct sock_fprog on amd64. +// +// +marshal type userSockFprog struct { // Len is the length of the filter in BPF instructions. Len uint16 @@ -33,7 +35,7 @@ type userSockFprog struct { // Filter is a user pointer to the struct sock_filter array that makes up // the filter program. Filter is a uint64 rather than a usermem.Addr // because usermem.Addr is actually uintptr, which is not a fixed-size - // type, and encoding/binary.Read objects to this. + // type. Filter uint64 } @@ -54,11 +56,11 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error { } var fprog userSockFprog - if _, err := t.CopyIn(addr, &fprog); err != nil { + if _, err := fprog.CopyIn(t, addr); err != nil { return err } filter := make([]linux.BPFInstruction, int(fprog.Len)) - if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil { + if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil { return err } compiledFilter, err := bpf.Compile(filter) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 5f54f2456..47dadb800 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -66,7 +67,7 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } ops := make([]linux.Sembuf, nsops) - if _, err := t.CopyIn(sembufAddr, ops); err != nil { + if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil { return 0, nil, err } @@ -116,8 +117,8 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case linux.IPC_SET: arg := args[3].Pointer() - s := linux.SemidDS{} - if _, err := t.CopyIn(arg, &s); err != nil { + var s linux.SemidDS + if _, err := s.CopyIn(t, arg); err != nil { return 0, nil, err } @@ -188,7 +189,7 @@ func setValAll(t *kernel.Task, id int32, array usermem.Addr) error { return syserror.EINVAL } vals := make([]uint16, set.Size()) - if _, err := t.CopyIn(array, vals); err != nil { + if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil { return err } creds := auth.CredentialsFromContext(t) @@ -217,7 +218,7 @@ func getValAll(t *kernel.Task, id int32, array usermem.Addr) error { if err != nil { return err } - _, err = t.CopyOut(array, vals) + _, err = primitive.CopyUint16SliceOut(t, array, vals) return err } diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index 4a8bc24a2..584064143 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -39,7 +39,7 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer segment.DecRef() + defer segment.DecRef(t) return uintptr(segment.ID), nil, nil } @@ -66,7 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, syserror.EINVAL } - defer segment.DecRef() + defer segment.DecRef(t) opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{ Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC, @@ -108,22 +108,22 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } - defer segment.DecRef() + defer segment.DecRef(t) stat, err := segment.IPCStat(t) if err == nil { - _, err = t.CopyOut(buf, stat) + _, err = stat.CopyOut(t, buf) } return 0, nil, err case linux.IPC_INFO: params := r.IPCInfo() - _, err := t.CopyOut(buf, params) + _, err := params.CopyOut(t, buf) return 0, nil, err case linux.SHM_INFO: info := r.ShmInfo() - _, err := t.CopyOut(buf, info) + _, err := info.CopyOut(t, buf) return 0, nil, err } @@ -132,20 +132,19 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } - defer segment.DecRef() + defer segment.DecRef(t) switch cmd { case linux.IPC_SET: var ds linux.ShmidDS - _, err = t.CopyIn(buf, &ds) - if err != nil { + if _, err = ds.CopyIn(t, buf); err != nil { return 0, nil, err } - err = segment.Set(t, &ds) + err := segment.Set(t, &ds) return 0, nil, err case linux.IPC_RMID: - segment.MarkDestroyed() + segment.MarkDestroyed(t) return 0, nil, nil case linux.SHM_LOCK, linux.SHM_UNLOCK: diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index d2b0012ae..e748d33d8 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -348,7 +348,7 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Pause implements linux syscall pause(2). func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND) + return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND) } // RtSigpending implements linux syscall rt_sigpending(2). @@ -496,7 +496,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. t.SetSavedSignalMask(oldmask) // Perform the wait. - return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND) + return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND) } // RestartSyscall implements the linux syscall restart_syscall(2). @@ -536,7 +536,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Is this a signalfd? if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok { @@ -553,7 +553,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) // Set appropriate flags. file.SetFlags(fs.SettableFileFlags{ diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..9feaca0da 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -65,10 +67,10 @@ const flagsOffset = 48 const sizeOfInt32 = 4 // messageHeader64Len is the length of a MessageHeader64 struct. -var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) +var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes()) // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. -var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) +var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes()) // baseRecvFlags are the flags that are accepted across recvmsg(2), // recvmmsg(2), and recvfrom(2). @@ -76,6 +78,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | // MessageHeader64 is the 64-bit representation of the msghdr struct used in // the recvmsg and sendmsg syscalls. +// +// +marshal type MessageHeader64 struct { // Name is the optional pointer to a network address buffer. Name uint64 @@ -104,30 +108,14 @@ type MessageHeader64 struct { // multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in // the recvmmsg and sendmmsg syscalls. +// +// +marshal type multipleMessageHeader64 struct { msgHdr MessageHeader64 msgLen uint32 _ int32 } -// CopyInMessageHeader64 copies a message header from user to kernel memory. -func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { - b := t.CopyScratchBuffer(52) - if _, err := t.CopyInBytes(addr, b); err != nil { - return err - } - - msg.Name = usermem.ByteOrder.Uint64(b[0:]) - msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) - msg.Iov = usermem.ByteOrder.Uint64(b[16:]) - msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) - msg.Control = usermem.ByteOrder.Uint64(b[32:]) - msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) - msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) - - return nil -} - // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { @@ -146,10 +134,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { // Get the buffer length. var bufLen uint32 - if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { return err } @@ -158,7 +146,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Write the length unconditionally. - if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil { return err } @@ -171,7 +159,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Copy as much of the address as will fit in the buffer. - encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + encodedAddr := t.CopyScratchBuffer(addr.SizeBytes()) + addr.MarshalUnsafe(encodedAddr) if bufLen > uint32(len(encodedAddr)) { bufLen = uint32(len(encodedAddr)) } @@ -198,7 +187,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal s.SetFlags(fs.SettableFileFlags{ NonBlocking: stype&linux.SOCK_NONBLOCK != 0, }) - defer s.DecRef() + defer s.DecRef(t) fd, err := t.NewFDFrom(0, s, kernel.FDFlags{ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0, @@ -233,8 +222,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } s1.SetFlags(fileFlags) s2.SetFlags(fileFlags) - defer s1.DecRef() - defer s2.DecRef() + defer s1.DecRef(t) + defer s2.DecRef(t) // Create the FDs for the sockets. fds, err := t.NewFDs(0, []*fs.File{s1, s2}, kernel.FDFlags{ @@ -245,10 +234,10 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Copy the file descriptors out. - if _, err := t.CopyOut(socks, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, socks, fds); err != nil { for _, fd := range fds { - if file, _ := t.FDTable().Remove(fd); file != nil { - file.DecRef() + if file, _ := t.FDTable().Remove(t, fd); file != nil { + file.DecRef(t) } } return 0, nil, err @@ -268,7 +257,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -283,7 +272,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } blocking := !file.Flags().NonBlocking - return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS) } // accept is the implementation of the accept syscall. It is called by accept @@ -299,7 +288,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -314,7 +303,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f peerRequested := addrLen != 0 nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } if peerRequested { // NOTE(magi): Linux does not give you an error if it can't @@ -358,7 +347,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -385,7 +374,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -414,7 +403,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -445,7 +434,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -454,8 +443,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Read the length. Reject negative values. - optLen := int32(0) - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + var optLen int32 + if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil { return 0, nil, err } if optLen < 0 { @@ -469,12 +458,12 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +473,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +485,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -524,7 +516,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -539,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } @@ -562,7 +554,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -590,7 +582,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -623,7 +615,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -676,7 +668,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -728,7 +720,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -743,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -766,16 +758,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { - return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS) } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Release() + cms.Release(t) } if int(msg.Flags) != mflags { // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -788,9 +780,9 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } - defer cms.Release() + defer cms.Release(t) controlData := make([]byte, 0, msg.ControlLen) controlData = control.PackControlMessages(t, cms, controlData) @@ -812,17 +804,17 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } // Copy the control data to the caller. - if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -846,7 +838,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -875,9 +867,9 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Release() + cm.Release(t) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } // Copy the address to the caller. @@ -919,7 +911,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -957,7 +949,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -991,7 +983,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -1006,7 +998,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -1017,7 +1009,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -1059,9 +1051,9 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) - err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) + err = handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) if err != nil { - controlMessages.Release() + controlMessages.Release(t) } return uintptr(n), err } @@ -1079,7 +1071,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -1119,7 +1111,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) - return uintptr(n), handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) + return uintptr(n), handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file) } // SendTo implements the linux syscall sendto(2). diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 77c78889d..46616c961 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -101,7 +102,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) if !inFile.Flags().Read { return 0, nil, syserror.EBADF @@ -111,7 +112,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) if !outFile.Flags().Write { return 0, nil, syserror.EBADF @@ -141,7 +142,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Copy in the offset. var offset int64 - if _, err := t.CopyIn(offsetAddr, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, offsetAddr, &offset); err != nil { return 0, nil, err } @@ -149,11 +150,11 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, SrcOffset: true, - SrcStart: offset, + SrcStart: int64(offset), }, outFile.Flags().NonBlocking) // Copy out the new offset. - if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { + if _, err := primitive.CopyInt64Out(t, offsetAddr, offset+n); err != nil { return 0, nil, err } } else { @@ -170,7 +171,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -192,13 +193,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) inFile := t.GetFile(inFD) if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) // The operation is non-blocking if anything is non-blocking. // @@ -228,7 +229,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } var offset int64 - if _, err := t.CopyIn(outOffset, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, outOffset, &offset); err != nil { return 0, nil, err } @@ -246,7 +247,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } var offset int64 - if _, err := t.CopyIn(inOffset, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, inOffset, &offset); err != nil { return 0, nil, err } @@ -280,7 +281,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "splice", inFile) } // Tee imlements tee(2). @@ -300,13 +301,13 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) inFile := t.GetFile(inFD) if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) // All files must be pipes. if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) { @@ -333,5 +334,5 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 46ebf27a2..cda29a8b5 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -58,7 +58,7 @@ func Fstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, fstat(t, file, statAddr) } @@ -100,7 +100,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, fstat(t, file, statAddr) } @@ -158,7 +158,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) uattr, err := file.UnstableAttr(t) if err != nil { return 0, nil, err @@ -221,7 +221,7 @@ func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr DevMajor: uint32(devMajor), DevMinor: devMinor, } - _, err := t.CopyOut(statxAddr, &s) + _, err := s.CopyOut(t, statxAddr) return err } @@ -249,7 +249,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, statfsImpl(t, file.Dirent, statfsAddr) } @@ -283,7 +283,7 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error { FragmentSize: d.Inode.StableAttr.BlockSize, // Leave other fields 0 like simple_statfs does. } - _, err = t.CopyOut(addr, &statfs) + _, err = statfs.CopyOut(t, addr) return err } diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index 5ad465ae3..048a21c6e 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -39,7 +39,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Use "sync-the-world" for now, it's guaranteed that fd is at least // on the root filesystem. @@ -54,10 +54,10 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll) - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // Fdatasync implements linux syscall fdatasync(2). @@ -70,10 +70,10 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData) - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // SyncFileRange implements linux syscall sync_file_rage(2) @@ -103,7 +103,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // SYNC_FILE_RANGE_WAIT_BEFORE waits upon write-out of all pages in the // specified range that have already been submitted to the device @@ -135,7 +135,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel err = file.Fsync(t, offset, fs.FileMaxOffset, fs.SyncData) } - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // LINT.ThenChange(vfs2/sync.go) diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go index 297de052a..674d341b6 100644 --- a/pkg/sentry/syscalls/linux/sys_sysinfo.go +++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go @@ -43,6 +43,6 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca FreeRAM: memFree, Unit: 1, } - _, err := t.CopyOut(addr, si) + _, err := si.CopyOut(t, addr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 00915fdde..39ca9ea97 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -19,6 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -117,7 +118,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) var wd *fs.Dirent var executable fsbridge.File @@ -133,7 +134,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) closeOnExec = fdFlags.CloseOnExec if atEmptyPath && len(pathname) == 0 { @@ -155,7 +156,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user } } if wd != nil { - defer wd.DecRef() + defer wd.DecRef(t) } // Load the new TaskContext. @@ -262,7 +263,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error { wopts.Events |= kernel.EventGroupContinue } if options&linux.WNOHANG == 0 { - wopts.BlockInterruptErr = kernel.ERESTARTSYS + wopts.BlockInterruptErr = syserror.ERESTARTSYS } if options&linux.WNOTHREAD == 0 { wopts.SiblingChildren = true @@ -311,13 +312,13 @@ func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusage return 0, err } if statusAddr != 0 { - if _, err := t.CopyOut(statusAddr, wr.Status); err != nil { + if _, err := primitive.CopyUint32Out(t, statusAddr, wr.Status); err != nil { return 0, err } } if rusageAddr != 0 { ru := getrusage(wr.Task, linux.RUSAGE_BOTH) - if _, err := t.CopyOut(rusageAddr, &ru); err != nil { + if _, err := ru.CopyOut(t, rusageAddr); err != nil { return 0, err } } @@ -395,14 +396,14 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // as well. if infop != 0 { var si arch.SignalInfo - _, err = t.CopyOut(infop, &si) + _, err = si.CopyOut(t, infop) } } return 0, nil, err } if rusageAddr != 0 { ru := getrusage(wr.Task, linux.RUSAGE_BOTH) - if _, err := t.CopyOut(rusageAddr, &ru); err != nil { + if _, err := ru.CopyOut(t, rusageAddr); err != nil { return 0, nil, err } } @@ -441,7 +442,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal default: t.Warningf("waitid got incomprehensible wait status %d", s) } - _, err = t.CopyOut(infop, &si) + _, err = si.CopyOut(t, infop) return 0, nil, err } @@ -558,9 +559,7 @@ func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // third argument to this system call is nowadays unused. if cpu != 0 { - buf := t.CopyScratchBuffer(4) - usermem.ByteOrder.PutUint32(buf, uint32(t.CPU())) - if _, err := t.CopyOutBytes(cpu, buf); err != nil { + if _, err := primitive.CopyInt32Out(t, cpu, t.CPU()); err != nil { return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 2d2aa0819..c5054d2f1 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -168,7 +169,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(r), nil, nil } - if _, err := t.CopyOut(addr, r); err != nil { + if _, err := r.CopyOut(t, addr); err != nil { return 0, nil, err } return uintptr(r), nil, nil @@ -213,7 +214,7 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error return nil } - return syserror.ConvertIntr(err, kernel.ERESTARTNOHAND) + return syserror.ConvertIntr(err, syserror.ERESTARTNOHAND) } // clockNanosleepFor blocks for a specified duration. @@ -254,7 +255,7 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use duration: remaining, rem: rem, }) - return kernel.ERESTART_RESTARTBLOCK + return syserror.ERESTART_RESTARTBLOCK default: panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } @@ -334,8 +335,8 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // Ask the time package for the timezone. _, offset := time.Now().Zone() // This int32 array mimics linux's struct timezone. - timezone := [2]int32{-int32(offset) / 60, 0} - _, err := t.CopyOut(tz, timezone) + timezone := []int32{-int32(offset) / 60, 0} + _, err := primitive.CopyInt32SliceOut(t, tz, timezone) return 0, nil, err } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go index a4c400f87..45eef4feb 100644 --- a/pkg/sentry/syscalls/linux/sys_timer.go +++ b/pkg/sentry/syscalls/linux/sys_timer.go @@ -21,81 +21,63 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const nsecPerSec = int64(time.Second) -// copyItimerValIn copies an ItimerVal from the untrusted app range to the -// kernel. The ItimerVal may be either 32 or 64 bits. -// A NULL address is allowed because because Linux allows -// setitimer(which, NULL, &old_value) which disables the timer. -// There is a KERN_WARN message saying this misfeature will be removed. -// However, that hasn't happened as of 3.19, so we continue to support it. -func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) { - if addr == usermem.Addr(0) { - return linux.ItimerVal{}, nil - } - - switch t.Arch().Width() { - case 8: - // Native size, just copy directly. - var itv linux.ItimerVal - if _, err := t.CopyIn(addr, &itv); err != nil { - return linux.ItimerVal{}, err - } - - return itv, nil - default: - return linux.ItimerVal{}, syserror.ENOSYS - } -} - -// copyItimerValOut copies an ItimerVal to the untrusted app range. -// The ItimerVal may be either 32 or 64 bits. -// A NULL address is allowed, in which case no copy takes place -func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error { - if addr == usermem.Addr(0) { - return nil - } - - switch t.Arch().Width() { - case 8: - // Native size, just copy directly. - _, err := t.CopyOut(addr, itv) - return err - default: - return syserror.ENOSYS - } -} - // Getitimer implements linux syscall getitimer(2). func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + if t.Arch().Width() != 8 { + // Definition of linux.ItimerVal assumes 64-bit architecture. + return 0, nil, syserror.ENOSYS + } + timerID := args[0].Int() - val := args[1].Pointer() + addr := args[1].Pointer() olditv, err := t.Getitimer(timerID) if err != nil { return 0, nil, err } - return 0, nil, copyItimerValOut(t, val, &olditv) + // A NULL address is allowed, in which case no copy out takes place. + if addr == 0 { + return 0, nil, nil + } + _, err = olditv.CopyOut(t, addr) + return 0, nil, err } // Setitimer implements linux syscall setitimer(2). func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - timerID := args[0].Int() - newVal := args[1].Pointer() - oldVal := args[2].Pointer() + if t.Arch().Width() != 8 { + // Definition of linux.ItimerVal assumes 64-bit architecture. + return 0, nil, syserror.ENOSYS + } - newitv, err := copyItimerValIn(t, newVal) - if err != nil { - return 0, nil, err + timerID := args[0].Int() + newAddr := args[1].Pointer() + oldAddr := args[2].Pointer() + + var newitv linux.ItimerVal + // A NULL address is allowed because because Linux allows + // setitimer(which, NULL, &old_value) which disables the timer. There is a + // KERN_WARN message saying this misfeature will be removed. However, that + // hasn't happened as of 3.19, so we continue to support it. + if newAddr != 0 { + if _, err := newitv.CopyIn(t, newAddr); err != nil { + return 0, nil, err + } } olditv, err := t.Setitimer(timerID, newitv) if err != nil { return 0, nil, err } - return 0, nil, copyItimerValOut(t, oldVal, &olditv) + // A NULL address is allowed, in which case no copy out takes place. + if oldAddr == 0 { + return 0, nil, nil + } + _, err = olditv.CopyOut(t, oldAddr) + return 0, nil, err } // Alarm implements linux syscall alarm(2). @@ -131,7 +113,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S var sev *linux.Sigevent if sevp != 0 { sev = &linux.Sigevent{} - if _, err = t.CopyIn(sevp, sev); err != nil { + if _, err = sev.CopyIn(t, sevp); err != nil { return 0, nil, err } } @@ -141,7 +123,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } - if _, err := t.CopyOut(timerIDp, &id); err != nil { + if _, err := id.CopyOut(t, timerIDp); err != nil { t.IntervalTimerDelete(id) return 0, nil, err } @@ -157,7 +139,7 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. oldValAddr := args[3].Pointer() var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0) @@ -165,9 +147,8 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, err } if oldValAddr != 0 { - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { - return 0, nil, err - } + _, err = oldVal.CopyOut(t, oldValAddr) + return 0, nil, err } return 0, nil, nil } @@ -181,7 +162,7 @@ func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if err != nil { return 0, nil, err } - _, err = t.CopyOut(curValAddr, &curVal) + _, err = curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go index cf49b43db..cadd9d348 100644 --- a/pkg/sentry/syscalls/linux/sys_timerfd.go +++ b/pkg/sentry/syscalls/linux/sys_timerfd.go @@ -43,7 +43,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.EINVAL } f := timerfd.NewFile(t, c) - defer f.DecRef() + defer f.DecRef(t) f.SetFlags(fs.SettableFileFlags{ NonBlocking: flags&linux.TFD_NONBLOCK != 0, }) @@ -73,7 +73,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) tf, ok := f.FileOperations.(*timerfd.TimerOperations) if !ok { @@ -81,7 +81,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock()) @@ -91,7 +91,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, oldS := tf.SetTime(newS) if oldValAddr != 0 { oldVal := ktime.ItimerspecFromSetting(tm, oldS) - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + if _, err := oldVal.CopyOut(t, oldValAddr); err != nil { return 0, nil, err } } @@ -107,7 +107,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) tf, ok := f.FileOperations.(*timerfd.TimerOperations) if !ok { @@ -116,6 +116,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, s := tf.GetTime() curVal := ktime.ItimerspecFromSetting(tm, s) - _, err := t.CopyOut(curValAddr, &curVal) + _, err := curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go index b3eb96a1c..6ddd30d5c 100644 --- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go @@ -18,6 +18,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -30,17 +31,19 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys case linux.ARCH_GET_FS: addr := args[1].Pointer() fsbase := t.Arch().TLS() - _, err := t.CopyOut(addr, uint64(fsbase)) - if err != nil { - return 0, nil, err + switch t.Arch().Width() { + case 8: + if _, err := primitive.CopyUint64Out(t, addr, uint64(fsbase)); err != nil { + return 0, nil, err + } + default: + return 0, nil, syserror.ENOSYS } - case linux.ARCH_SET_FS: fsbase := args[1].Uint64() if !t.Arch().SetTLS(uintptr(fsbase)) { return 0, nil, syserror.EPERM } - case linux.ARCH_GET_GS, linux.ARCH_SET_GS: t.Kernel().EmitUnimplementedEvent(t) fallthrough diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index e9d702e8e..66c5974f5 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -46,7 +46,7 @@ func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Copy out the result. va := args[0].Pointer() - _, err := t.CopyOut(va, u) + _, err := u.CopyOut(t, va) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 6ec0de96e..95bfe6606 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -48,7 +48,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is writable. if !file.Flags().Write { @@ -71,7 +71,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "write", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "write", file) } // Pwrite64 implements linux syscall pwrite64(2). @@ -85,7 +85,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -118,7 +118,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file) } // Writev implements linux syscall writev(2). @@ -131,7 +131,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is writable. if !file.Flags().Write { @@ -148,7 +148,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "writev", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "writev", file) } // Pwritev implements linux syscall pwritev(2). @@ -162,7 +162,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -189,7 +189,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file) } // Pwritev2 implements linux syscall pwritev2(2). @@ -215,7 +215,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { @@ -250,12 +250,12 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if offset == -1 { n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) } n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) } func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index c24946160..97474fd3c 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -49,7 +49,7 @@ func FGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) n, err := getXattr(t, f.Dirent, nameAddr, valueAddr, size) if err != nil { @@ -153,7 +153,7 @@ func FSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) return 0, nil, setXattr(t, f.Dirent, nameAddr, valueAddr, uint64(size), flags) } @@ -270,7 +270,7 @@ func FListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) n, err := listXattr(t, f.Dirent, listAddr, size) if err != nil { @@ -384,7 +384,7 @@ func FRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) return 0, nil, removeXattr(t, f.Dirent, nameAddr) } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 0c740335b..9ee766552 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -44,6 +44,9 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/gohacks", + "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index e5cdefc50..6d0a38330 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -17,6 +17,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -38,21 +39,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } for i := int32(0); i < nrEvents; i++ { - // Copy in the address. - cbAddrNative := t.Arch().Native(0) - if _, err := t.CopyIn(addr, cbAddrNative); err != nil { - if i > 0 { - // Some successful. - return uintptr(i), nil, nil + // Copy in the callback address. + var cbAddr usermem.Addr + switch t.Arch().Width() { + case 8: + var cbAddrP primitive.Uint64 + if _, err := cbAddrP.CopyIn(t, addr); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err } - // Nothing done. - return 0, nil, err + cbAddr = usermem.Addr(cbAddrP) + default: + return 0, nil, syserror.ENOSYS } // Copy in this callback. var cb linux.IOCallback - cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) - if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if _, err := cb.CopyIn(t, cbAddr); err != nil { if i > 0 { // Some have been successful. return uintptr(i), nil, nil @@ -88,7 +95,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if fd == nil { return syserror.EBADF } - defer fd.DecRef() + defer fd.DecRef(t) // Was there an eventFD? Extract it. var eventFD *vfs.FileDescription @@ -97,7 +104,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if eventFD == nil { return syserror.EBADF } - defer eventFD.DecRef() + defer eventFD.DecRef(t) // Check that it is an eventfd. if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok { @@ -144,6 +151,12 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback { return func(ctx context.Context) { + // Release references after completing the callback. + defer fd.DecRef(ctx) + if eventFD != nil { + defer eventFD.DecRef(ctx) + } + if aioCtx.Dead() { aioCtx.CancelPendingRequest() return @@ -169,8 +182,6 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use ev.Result = -int64(kernel.ExtractErrno(err, 0)) } - fd.DecRef() - // Queue the result for delivery. aioCtx.FinishRequest(ev) @@ -179,7 +190,6 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use // wake up. if eventFD != nil { eventFD.Impl().(*eventfd.EventFileDescription).Signal(1) - eventFD.DecRef() } } } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index 34c90ae3e..d0cbb77eb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -24,7 +24,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -37,11 +36,11 @@ func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, syserror.EINVAL } - file, err := t.Kernel().VFS().NewEpollInstanceFD() + file, err := t.Kernel().VFS().NewEpollInstanceFD(t) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0, @@ -62,11 +61,11 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - file, err := t.Kernel().VFS().NewEpollInstanceFD() + file, err := t.Kernel().VFS().NewEpollInstanceFD(t) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) if err != nil { @@ -86,7 +85,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if epfile == nil { return 0, nil, syserror.EBADF } - defer epfile.DecRef() + defer epfile.DecRef(t) ep, ok := epfile.Impl().(*vfs.EpollInstance) if !ok { return 0, nil, syserror.EINVAL @@ -95,7 +94,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) if epfile == file { return 0, nil, syserror.EINVAL } @@ -135,56 +134,32 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if epfile == nil { return 0, nil, syserror.EBADF } - defer epfile.DecRef() + defer epfile.DecRef(t) ep, ok := epfile.Impl().(*vfs.EpollInstance) if !ok { return 0, nil, syserror.EINVAL } - // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent, - // maxEvents), so that the buffer can be allocated on the stack. + // Allocate space for a few events on the stack for the common case in + // which we don't have too many events. var ( - events [16]linux.EpollEvent - total int + eventsArr [16]linux.EpollEvent ch chan struct{} haveDeadline bool deadline ktime.Time ) for { - batchEvents := len(events) - if batchEvents > maxEvents { - batchEvents = maxEvents - } - n := ep.ReadEvents(events[:batchEvents]) - maxEvents -= n - if n != 0 { - // Copy what we read out. - copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n]) + events := ep.ReadEvents(eventsArr[:0], maxEvents) + if len(events) != 0 { + copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events) copiedEvents := copiedBytes / sizeofEpollEvent // rounded down - eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent) - total += copiedEvents - if err != nil { - if total != 0 { - return uintptr(total), nil, nil - } - return 0, nil, err - } - // If we've filled the application's event buffer, we're done. - if maxEvents == 0 { - return uintptr(total), nil, nil - } - // Loop if we read a full batch, under the expectation that there - // may be more events to read. - if n == batchEvents { - continue + if copiedEvents != 0 { + return uintptr(copiedEvents), nil, nil } + return 0, nil, err } - // We get here if n != batchEvents. If we read any number of events - // (just now, or in a previous iteration of this loop), or if timeout - // is 0 (such that epoll_wait should be non-blocking), return the - // events we've read so far to the application. - if total != 0 || timeout == 0 { - return uintptr(total), nil, nil + if timeout == 0 { + return 0, nil, nil } // In the first iteration of this loop, register with the epoll // instance for readability events, but then immediately continue the @@ -207,8 +182,6 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if err == syserror.ETIMEDOUT { err = nil } - // total must be 0 since otherwise we would have returned - // above. return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go index aff1a2070..807f909da 100644 --- a/pkg/sentry/syscalls/linux/vfs2/eventfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go @@ -38,11 +38,11 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc fileFlags |= linux.O_NONBLOCK } semMode := flags&linux.EFD_SEMAPHORE != 0 - eventfd, err := eventfd.New(vfsObj, initVal, semMode, fileFlags) + eventfd, err := eventfd.New(t, vfsObj, initVal, semMode, fileFlags) if err != nil { return 0, nil, err } - defer eventfd.DecRef() + defer eventfd.DecRef(t) fd, err := t.NewFDFromVFS2(0, eventfd, kernel.FDFlags{ CloseOnExec: flags&linux.EFD_CLOEXEC != 0, diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index aef0078a8..066ee0863 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -71,7 +71,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user } root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) var executable fsbridge.File closeOnExec := false if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute { @@ -90,7 +90,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user } start := dirfile.VirtualDentry() start.IncRef() - dirfile.DecRef() + dirfile.DecRef(t) closeOnExec = dirfileFlags.CloseOnExec file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{ Root: root, @@ -101,19 +101,19 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user Flags: linux.O_RDONLY, FileExec: true, }) - start.DecRef() + start.DecRef(t) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) executable = fsbridge.NewVFSFile(file) } // Load the new TaskContext. mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change - defer mntns.DecRef() + defer mntns.DecRef(t) wd := t.FSContext().WorkingDirectoryVFS2() - defer wd.DecRef() + defer wd.DecRef(t) remainingTraversals := uint(linux.MaxSymlinkTraversals) loadArgs := loader.LoadArgs{ Opener: fsbridge.NewVFSLookup(mntns, root, wd), diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 517394ba9..d8b8d9783 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -34,11 +34,11 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that Remove provides a reference on the file that we may use to // flush. It is still active until we drop the final reference below // (and other reference-holding operations complete). - _, file := t.FDTable().Remove(fd) + _, file := t.FDTable().Remove(t, fd) if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.OnClose(t) return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file) @@ -52,7 +52,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) if err != nil { @@ -72,7 +72,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - file.DecRef() + file.DecRef(t) return uintptr(newfd), nil, nil } @@ -101,7 +101,7 @@ func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -121,7 +121,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) switch cmd { case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: @@ -137,7 +137,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + err := t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) return 0, nil, err @@ -181,15 +181,15 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if !hasOwner { return 0, nil, nil } - _, err := t.CopyOut(args[2].Pointer(), &owner) + _, err := owner.CopyOut(t, args[2].Pointer()) return 0, nil, err case linux.F_SETOWN_EX: var owner linux.FOwnerEx - n, err := t.CopyIn(args[2].Pointer(), &owner) + _, err := owner.CopyIn(t, args[2].Pointer()) if err != nil { return 0, nil, err } - return uintptr(n), nil, setAsyncOwner(t, file, owner.Type, owner.PID) + return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) case linux.F_GETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -208,7 +208,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_SETLK, linux.F_SETLKW: return 0, nil, posixLock(t, args, file, cmd) default: - // TODO(gvisor.dev/issue/2920): Everything else is not yet supported. + // Everything else is not yet supported. return 0, nil, syserror.EINVAL } } @@ -286,7 +286,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock - if _, err := t.CopyIn(flockAddr, &flock); err != nil { + if _, err := flock.CopyIn(t, flockAddr); err != nil { return err } @@ -332,7 +332,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // If the FD refers to a pipe or FIFO, return error. if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe { diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index b12b5967b..01e0f9010 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -57,7 +56,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd i if err != nil { return err } - defer oldtpop.Release() + defer oldtpop.Release(t) newpath, err := copyInPath(t, newpathAddr) if err != nil { @@ -67,7 +66,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd i if err != nil { return err } - defer newtpop.Release() + defer newtpop.Release(t) return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop) } @@ -96,7 +95,7 @@ func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error { if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()), }) @@ -107,7 +106,7 @@ func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall addr := args[0].Pointer() mode := args[1].ModeT() dev := args[2].Uint() - return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev) + return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev) } // Mknodat implements Linux syscall mknodat(2). @@ -116,10 +115,10 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca addr := args[1].Pointer() mode := args[2].ModeT() dev := args[3].Uint() - return 0, nil, mknodat(t, dirfd, addr, mode, dev) + return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev) } -func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error { +func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error { path, err := copyInPath(t, addr) if err != nil { return err @@ -128,10 +127,15 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + if mode.FileType() == 0 { + mode |= linux.ModeRegular + } major, minor := linux.DecodeDeviceID(dev) return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{ - Mode: linux.FileMode(mode &^ t.FSContext().Umask()), + Mode: mode &^ linux.FileMode(t.FSContext().Umask()), DevMajor: uint32(major), DevMinor: minor, }) @@ -170,7 +174,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{ Flags: flags | linux.O_LARGEFILE, @@ -179,7 +183,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -223,7 +227,7 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd if err != nil { return err } - defer oldtpop.Release() + defer oldtpop.Release(t) newpath, err := copyInPath(t, newpathAddr) if err != nil { @@ -233,62 +237,13 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd if err != nil { return err } - defer newtpop.Release() + defer newtpop.Release(t) return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{ Flags: flags, }) } -// Fallocate implements linux system call fallocate(2). -func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - fd := args[0].Int() - mode := args[1].Uint64() - offset := args[2].Int64() - length := args[3].Int64() - - file := t.GetFileVFS2(fd) - - if file == nil { - return 0, nil, syserror.EBADF - } - defer file.DecRef() - - if !file.IsWritable() { - return 0, nil, syserror.EBADF - } - - if mode != 0 { - return 0, nil, syserror.ENOTSUP - } - - if offset < 0 || length <= 0 { - return 0, nil, syserror.EINVAL - } - - size := offset + length - - if size < 0 { - return 0, nil, syserror.EFBIG - } - - limit := limits.FromContext(t).Get(limits.FileSize).Cur - - if uint64(size) >= limit { - t.SendSignal(&arch.SignalInfo{ - Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, - }) - return 0, nil, syserror.EFBIG - } - - return 0, nil, file.Impl().Allocate(t, mode, uint64(offset), uint64(length)) - - // File length modified, generate notification. - // TODO(gvisor.dev/issue/1479): Reenable when Inotify is ported. - // file.Dirent.InotifyEvent(linux.IN_MODIFY, 0) -} - // Rmdir implements Linux syscall rmdir(2). func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { pathAddr := args[0].Pointer() @@ -304,7 +259,7 @@ func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop) } @@ -323,7 +278,7 @@ func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop) } @@ -374,6 +329,6 @@ func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpath if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target) } diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go index 317409a18..a7d4d2a36 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fscontext.go +++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go @@ -31,8 +31,8 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal root := t.FSContext().RootDirectoryVFS2() wd := t.FSContext().WorkingDirectoryVFS2() s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd) - root.DecRef() - wd.DecRef() + root.DecRef(t) + wd.DecRef(t) if err != nil { return 0, nil, err } @@ -67,7 +67,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ CheckSearchable: true, @@ -75,8 +75,8 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - t.FSContext().SetWorkingDirectoryVFS2(vd) - vd.DecRef() + t.FSContext().SetWorkingDirectoryVFS2(t, vd) + vd.DecRef(t) return 0, nil, nil } @@ -88,7 +88,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ CheckSearchable: true, @@ -96,8 +96,8 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - t.FSContext().SetWorkingDirectoryVFS2(vd) - vd.DecRef() + t.FSContext().SetWorkingDirectoryVFS2(t, vd) + vd.DecRef(t) return 0, nil, nil } @@ -117,7 +117,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ CheckSearchable: true, @@ -125,7 +125,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - t.FSContext().SetRootDirectoryVFS2(vd) - vd.DecRef() + t.FSContext().SetRootDirectoryVFS2(t, vd) + vd.DecRef(t) return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index c7c7bf7ce..5517595b5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -44,7 +44,7 @@ func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (ui if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) cb := getGetdentsCallback(t, addr, size, isGetdents64) err := file.IterDirents(t, cb) diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go index 5d98134a5..11753d8e5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/inotify.go +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -35,7 +35,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if err != nil { return 0, nil, err } - defer ino.DecRef() + defer ino.DecRef(t) fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{ CloseOnExec: flags&linux.IN_CLOEXEC != 0, @@ -66,7 +66,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, ino, ok := f.Impl().(*vfs.Inotify) if !ok { // Not an inotify fd. - f.DecRef() + f.DecRef(t) return nil, nil, syserror.EINVAL } @@ -96,7 +96,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern if err != nil { return 0, nil, err } - defer f.DecRef() + defer f.DecRef(t) path, err := copyInPath(t, addr) if err != nil { @@ -109,12 +109,12 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{}) if err != nil { return 0, nil, err } - defer d.DecRef() + defer d.DecRef(t) fd, err = ino.AddWatch(d.Dentry(), mask) if err != nil { @@ -132,6 +132,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if err != nil { return 0, nil, err } - defer f.DecRef() - return 0, nil, ino.RmWatch(wd) + defer f.DecRef(t) + return 0, nil, ino.RmWatch(t, wd) } diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index fd6ab94b2..2806c3f6f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -29,25 +30,25 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Handle ioctls that apply to all FDs. switch args[1].Int() { case linux.FIONCLEX: - t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: false, }) return 0, nil, nil case linux.FIOCLEX: - t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: true, }) return 0, nil, nil case linux.FIONBIO: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.StatusFlags() @@ -60,7 +61,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOASYNC: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.StatusFlags() @@ -82,12 +83,12 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall who = owner.PID } } - _, err := t.CopyOut(args[2].Pointer(), &who) + _, err := primitive.CopyInt32Out(t, args[2].Pointer(), who) return 0, nil, err case linux.FIOSETOWN, linux.SIOCSPGRP: var who int32 - if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &who); err != nil { return 0, nil, err } ownerType := int32(linux.F_OWNER_PID) diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go index bf19028c4..b910b5a74 100644 --- a/pkg/sentry/syscalls/linux/vfs2/lock.go +++ b/pkg/sentry/syscalls/linux/vfs2/lock.go @@ -32,7 +32,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // flock(2): EBADF fd is not an open file descriptor. return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) nonblocking := operation&linux.LOCK_NB != 0 operation &^= linux.LOCK_NB diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go index bbe248d17..c4c0f9e0a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/memfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go @@ -47,10 +47,11 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S } shmMount := t.Kernel().ShmMount() - file, err := tmpfs.NewMemfd(shmMount, t.Credentials(), allowSeals, memfdPrefix+name) + file, err := tmpfs.NewMemfd(t, t.Credentials(), shmMount, allowSeals, memfdPrefix+name) if err != nil { return 0, nil, err } + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: cloExec, diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index 60a43f0a0..9d9dbf775 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -17,6 +17,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" @@ -61,7 +62,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } defer func() { if opts.MappingIdentity != nil { - opts.MappingIdentity.DecRef() + opts.MappingIdentity.DecRef(t) } }() @@ -71,7 +72,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // mmap unconditionally requires that the FD is readable. if !file.IsReadable() { @@ -85,6 +86,17 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if err := file.ConfigureMMap(t, &opts); err != nil { return 0, nil, err } + } else if shared { + // Back shared anonymous mappings with an anonymous tmpfs file. + opts.Offset = 0 + file, err := tmpfs.NewZeroFile(t, t.Credentials(), t.Kernel().ShmMount(), opts.Length) + if err != nil { + return 0, nil, err + } + defer file.DecRef(t) + if err := file.ConfigureMMap(t, &opts); err != nil { + return 0, nil, err + } } rv, err := t.MemoryManager().MMap(t, opts) diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index adeaa39cc..769c9b92f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Silently allow MS_NOSUID, since we don't implement set-id bits // anyway. - const unsupportedFlags = linux.MS_NODEV | - linux.MS_NODIRATIME | linux.MS_STRICTATIME + const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME // Linux just allows passing any flags to mount(2) - it won't fail when // unknown or unsupported flags are passed. Since we don't implement @@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if flags&linux.MS_NOEXEC == linux.MS_NOEXEC { opts.Flags.NoExec = true } + if flags&linux.MS_NODEV == linux.MS_NODEV { + opts.Flags.NoDev = true + } + if flags&linux.MS_NOSUID == linux.MS_NOSUID { + opts.Flags.NoSUID = true + } if flags&linux.MS_RDONLY == linux.MS_RDONLY { opts.ReadOnly = true } @@ -103,9 +108,9 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - defer target.Release() - - return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) + defer target.Release(t) + _, err = t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) + return 0, nil, err } // Umount2 implements Linux syscall umount2(2). @@ -135,7 +140,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) opts := vfs.UmountOptions{ Flags: uint32(flags), diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go index 97da6c647..90a511d9a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/path.go +++ b/pkg/sentry/syscalls/linux/vfs2/path.go @@ -42,7 +42,7 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA haveStartRef := false if !path.Absolute { if !path.HasComponents() && !bool(shouldAllowEmptyPath) { - root.DecRef() + root.DecRef(t) return taskPathOperation{}, syserror.ENOENT } if dirfd == linux.AT_FDCWD { @@ -51,13 +51,13 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { - root.DecRef() + root.DecRef(t) return taskPathOperation{}, syserror.EBADF } start = dirfile.VirtualDentry() start.IncRef() haveStartRef = true - dirfile.DecRef() + dirfile.DecRef(t) } } return taskPathOperation{ @@ -71,10 +71,10 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA }, nil } -func (tpop *taskPathOperation) Release() { - tpop.pop.Root.DecRef() +func (tpop *taskPathOperation) Release(t *kernel.Task) { + tpop.pop.Root.DecRef(t) if tpop.haveStartRef { - tpop.pop.Start.DecRef() + tpop.pop.Start.DecRef(t) tpop.haveStartRef = false } } diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index 4a01e4209..ee38fdca0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -42,8 +43,8 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { return syserror.EINVAL } r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(t) + defer w.DecRef(t) fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -51,10 +52,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if err != nil { return err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if _, file := t.FDTable().Remove(fd); file != nil { - file.DecRef() + if _, file := t.FDTable().Remove(t, fd); file != nil { + file.DecRef(t) } } return err diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index ff1b25d7b..c22e4ce54 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -73,7 +73,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } if ch == nil { - defer file.DecRef() + defer file.DecRef(t) } else { state.file = file state.waiter, _ = waiter.NewChannelEntry(ch) @@ -85,11 +85,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } // releaseState releases all the pollState in "state". -func releaseState(state []pollState) { +func releaseState(t *kernel.Task, state []pollState) { for i := range state { if state[i].file != nil { state[i].file.EventUnregister(&state[i].waiter) - state[i].file.DecRef() + state[i].file.DecRef(t) } } } @@ -110,7 +110,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // result, we stop registering for events but still go through all files // to get their ready masks. state := make([]pollState, len(pfd)) - defer releaseState(state) + defer releaseState(t, state) n := uintptr(0) for i := range pfd { initReadiness(t, &pfd[i], &state[i], ch) @@ -165,7 +165,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD pfd := make([]linux.PollFD, nfds) if nfds > 0 { - if _, err := t.CopyIn(addr, &pfd); err != nil { + if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil { return nil, err } } @@ -192,7 +192,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. if nfds > 0 && err == nil { - if _, err := t.CopyOut(addr, pfd); err != nil { + if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil { return remainingTimeout, 0, err } } @@ -205,7 +205,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy set := make([]byte, nBytes) if addr != 0 { - if _, err := t.CopyIn(addr, &set); err != nil { + if _, err := t.CopyInBytes(addr, set); err != nil { return nil, err } // If we only use part of the last byte, mask out the extraneous bits. @@ -269,7 +269,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add if file == nil { return 0, syserror.EBADF } - file.DecRef() + file.DecRef(t) var mask int16 if (rV & m) != 0 { @@ -332,19 +332,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Copy updated vectors back. if readFDs != 0 { - if _, err := t.CopyOut(readFDs, r); err != nil { + if _, err := t.CopyOutBytes(readFDs, r); err != nil { return 0, err } } if writeFDs != 0 { - if _, err := t.CopyOut(writeFDs, w); err != nil { + if _, err := t.CopyOutBytes(writeFDs, w); err != nil { return 0, err } } if exceptFDs != 0 { - if _, err := t.CopyOut(exceptFDs, e); err != nil { + if _, err := t.CopyOutBytes(exceptFDs, e); err != nil { return 0, err } } @@ -415,7 +415,7 @@ func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration nfds: nfds, timeout: remainingTimeout, }) - return 0, kernel.ERESTART_RESTARTBLOCK + return 0, syserror.ERESTART_RESTARTBLOCK } return n, err } @@ -462,7 +462,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } @@ -492,11 +492,17 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } +// +marshal +type sigSetWithSize struct { + sigsetAddr uint64 + sizeofSigset uint64 +} + // Pselect implements linux syscall pselect(2). func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { nfds := int(args[0].Int()) // select(2) uses an int. @@ -533,17 +539,11 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } -// +marshal -type sigSetWithSize struct { - sigsetAddr uint64 - sizeofSigset uint64 -} - // copyTimespecInToDuration copies a Timespec from the untrusted app range, // validates it and converts it to a Duration. // diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index cd25597a7..b77b29dcc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -44,7 +44,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the size is legitimate. si := int(size) @@ -62,7 +62,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC n, err := read(t, file, dst, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "read", file) } // Readv implements Linux syscall readv(2). @@ -75,7 +75,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Get the destination of the read. dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ @@ -87,14 +87,14 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := read(t, file, dst, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "readv", file) } func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { n, err := file.Read(t, dst, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -102,7 +102,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -135,7 +135,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return total, err } @@ -151,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -174,7 +174,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file) } // Preadv implements Linux syscall preadv(2). @@ -188,7 +188,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -205,7 +205,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file) } // Preadv2 implements Linux syscall preadv2(2). @@ -226,7 +226,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { @@ -251,14 +251,14 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err = pread(t, file, dst, offset, opts) } t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) } func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { n, err := file.PRead(t, dst, offset, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -266,7 +266,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -299,7 +299,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return total, err } @@ -314,7 +314,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the size is legitimate. si := int(size) @@ -332,7 +332,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := write(t, file, src, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "write", file) } // Writev implements Linux syscall writev(2). @@ -345,7 +345,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Get the source of the write. src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ @@ -357,14 +357,14 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := write(t, file, src, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "writev", file) } func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { n, err := file.Write(t, src, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return n, err } @@ -372,7 +372,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return n, err } @@ -405,7 +405,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return total, err } @@ -421,7 +421,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -444,7 +444,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file) } // Pwritev implements Linux syscall pwritev(2). @@ -458,7 +458,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -475,7 +475,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file) } // Pwritev2 implements Linux syscall pwritev2(2). @@ -496,7 +496,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { @@ -521,14 +521,14 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err = pwrite(t, file, src, offset, opts) } t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) } func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { n, err := file.PWrite(t, src, offset, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return n, err } @@ -536,7 +536,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -569,7 +569,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return total, err } @@ -601,7 +601,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) newoff, err := file.Seek(t, offset, whence) return uintptr(newoff), nil, err @@ -617,7 +617,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.IsReadable() { diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 09ecfed26..1ee37e5a8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -65,7 +66,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ @@ -150,7 +151,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) var opts vfs.SetStatOptions if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil { @@ -178,6 +179,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Mask: linux.STATX_SIZE, Size: uint64(length), }, + NeedWritePerm: true, }) return 0, nil, handleSetSizeError(t, err) } @@ -195,7 +197,11 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) + + if !file.IsWritable() { + return 0, nil, syserror.EINVAL + } err := file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ @@ -206,6 +212,56 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, handleSetSizeError(t, err) } +// Fallocate implements linux system call fallocate(2). +func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + mode := args[1].Uint64() + offset := args[2].Int64() + length := args[3].Int64() + + file := t.GetFileVFS2(fd) + + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef(t) + + if !file.IsWritable() { + return 0, nil, syserror.EBADF + } + + if mode != 0 { + return 0, nil, syserror.ENOTSUP + } + + if offset < 0 || length <= 0 { + return 0, nil, syserror.EINVAL + } + + size := offset + length + + if size < 0 { + return 0, nil, syserror.EFBIG + } + + limit := limits.FromContext(t).Get(limits.FileSize).Cur + + if uint64(size) >= limit { + t.SendSignal(&arch.SignalInfo{ + Signo: int32(linux.SIGXFSZ), + Code: arch.SignalInfoUser, + }) + return 0, nil, syserror.EFBIG + } + + if err := file.Allocate(t, mode, uint64(offset), uint64(length)); err != nil { + return 0, nil, err + } + + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) + return 0, nil, nil +} + // Utime implements Linux syscall utime(2). func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { pathAddr := args[0].Pointer() @@ -290,7 +346,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opt return nil } var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return err } if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 { @@ -354,7 +410,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op return nil } var times [2]linux.Timespec - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil { return err } if times[0].Nsec != linux.UTIME_OMIT { @@ -382,7 +438,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error { root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root if !path.Absolute { if !path.HasComponents() && !bool(shouldAllowEmptyPath) { @@ -390,7 +446,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { @@ -401,13 +457,13 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa // VirtualFilesystem.SetStatAt(), since the former may be able // to use opened file state to expedite the SetStat. err := dirfile.SetStat(t, *opts) - dirfile.DecRef() + dirfile.DecRef(t) return err } start = dirfile.VirtualDentry() start.IncRef() - defer start.DecRef() - dirfile.DecRef() + defer start.DecRef(t) + dirfile.DecRef(t) } } return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{ diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go index 623992f6f..b89f34cdb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/signal.go +++ b/pkg/sentry/syscalls/linux/vfs2/signal.go @@ -45,7 +45,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Is this a signalfd? if sfd, ok := file.Impl().(*signalfd.SignalFileDescription); ok { @@ -68,7 +68,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) // Create a new descriptor. fd, err = t.NewFDFromVFS2(0, file, kernel.FDFlags{ diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..bfae6b7e9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -64,10 +66,10 @@ const flagsOffset = 48 const sizeOfInt32 = 4 // messageHeader64Len is the length of a MessageHeader64 struct. -var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) +var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes()) // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. -var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) +var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes()) // baseRecvFlags are the flags that are accepted across recvmsg(2), // recvmmsg(2), and recvfrom(2). @@ -75,6 +77,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | // MessageHeader64 is the 64-bit representation of the msghdr struct used in // the recvmsg and sendmsg syscalls. +// +// +marshal type MessageHeader64 struct { // Name is the optional pointer to a network address buffer. Name uint64 @@ -103,30 +107,14 @@ type MessageHeader64 struct { // multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in // the recvmmsg and sendmmsg syscalls. +// +// +marshal type multipleMessageHeader64 struct { msgHdr MessageHeader64 msgLen uint32 _ int32 } -// CopyInMessageHeader64 copies a message header from user to kernel memory. -func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { - b := t.CopyScratchBuffer(52) - if _, err := t.CopyInBytes(addr, b); err != nil { - return err - } - - msg.Name = usermem.ByteOrder.Uint64(b[0:]) - msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) - msg.Iov = usermem.ByteOrder.Uint64(b[16:]) - msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) - msg.Control = usermem.ByteOrder.Uint64(b[32:]) - msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) - msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) - - return nil -} - // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { @@ -145,10 +133,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { // Get the buffer length. var bufLen uint32 - if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { return err } @@ -157,7 +145,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Write the length unconditionally. - if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil { return err } @@ -170,7 +158,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Copy as much of the address as will fit in the buffer. - encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + encodedAddr := t.CopyScratchBuffer(addr.SizeBytes()) + addr.MarshalUnsafe(encodedAddr) if bufLen > uint32(len(encodedAddr)) { bufLen = uint32(len(encodedAddr)) } @@ -194,7 +183,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if e != nil { return 0, nil, e.ToError() } - defer s.DecRef() + defer s.DecRef(t) if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil { return 0, nil, err @@ -228,8 +217,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } // Adding to the FD table will cause an extra reference to be acquired. - defer s1.DecRef() - defer s2.DecRef() + defer s1.DecRef(t) + defer s2.DecRef(t) nonblocking := uint32(stype & linux.SOCK_NONBLOCK) if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil { @@ -248,10 +237,10 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if _, file := t.FDTable().Remove(fd); file != nil { - file.DecRef() + if _, file := t.FDTable().Remove(t, fd); file != nil { + file.DecRef(t) } } return 0, nil, err @@ -271,7 +260,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -286,7 +275,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 - return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS) } // accept is the implementation of the accept syscall. It is called by accept @@ -302,7 +291,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -317,7 +306,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f peerRequested := addrLen != 0 nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } if peerRequested { // NOTE(magi): Linux does not give you an error if it can't @@ -361,7 +350,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -388,7 +377,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -417,7 +406,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -448,7 +437,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -457,8 +446,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Read the length. Reject negative values. - optLen := int32(0) - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + var optLen int32 + if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil { return 0, nil, err } if optLen < 0 { @@ -472,12 +461,12 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +488,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -527,7 +519,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -542,7 +534,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } @@ -565,7 +557,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -593,7 +585,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -626,7 +618,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -679,7 +671,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -731,7 +723,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -746,7 +738,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -769,16 +761,16 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { - return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS) } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Release() + cms.Release(t) } if int(msg.Flags) != mflags { // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -791,9 +783,9 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } - defer cms.Release() + defer cms.Release(t) controlData := make([]byte, 0, msg.ControlLen) controlData = control.PackControlMessages(t, cms, controlData) @@ -815,17 +807,17 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } // Copy the control data to the caller. - if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -849,7 +841,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -878,9 +870,9 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Release() + cm.Release(t) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } // Copy the address to the caller. @@ -922,7 +914,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -960,7 +952,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -994,7 +986,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -1009,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -1020,7 +1012,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -1062,9 +1054,9 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) - err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) + err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) if err != nil { - controlMessages.Release() + controlMessages.Release(t) } return uintptr(n), err } @@ -1082,7 +1074,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -1122,7 +1114,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) - return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) + return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file) } // SendTo implements the linux syscall sendto(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 945a364a7..bf5c1171f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -15,12 +15,18 @@ package vfs2 import ( + "io" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -50,12 +56,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) outFile := t.GetFileVFS2(outFD) if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) // Check that both files support the required directionality. if !inFile.IsReadable() || !outFile.IsWritable() { @@ -85,7 +91,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inFile.Options().DenyPRead { return 0, nil, syserror.EINVAL } - if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil { + if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil { return 0, nil, err } if inOffset < 0 { @@ -100,7 +106,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if outFile.Options().DenyPWrite { return 0, nil, syserror.EINVAL } - if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil { + if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil { return 0, nil, err } if outOffset < 0 { @@ -110,89 +116,67 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Move data. var ( - n int64 - err error - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { // If both input and output are pipes, delegate to the pipe - // implementation. Otherwise, exactly one end is a pipe, which we - // ensure is consistently ordered after the non-pipe FD's locks by - // passing the pipe FD as usermem.IO to the non-pipe end. + // implementation. Otherwise, exactly one end is a pipe, which + // we ensure is consistently ordered after the non-pipe FD's + // locks by passing the pipe FD as usermem.IO to the non-pipe + // end. switch { case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) case inIsPipe: + n, err = inPipeFD.SpliceToNonPipe(t, outFile, outOffset, count) if outOffset != -1 { - n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{}) outOffset += n - } else { - n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{}) } case outIsPipe: + n, err = outPipeFD.SpliceFromNonPipe(t, inFile, inOffset, count) if inOffset != -1 { - n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{}) inOffset += n - } else { - n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } + default: + panic("at least one end of splice must be a pipe") } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } - - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the splice operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. - } - if err = t.Block(inCh); err != nil { - break - } - } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + if err = dw.waitForBoth(t); err != nil { + break } } // Copy updated offsets out. if inOffsetPtr != 0 { - if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil { + if _, err := primitive.CopyInt64Out(t, inOffsetPtr, inOffset); err != nil { return 0, nil, err } } if outOffsetPtr != 0 { - if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil { + if _, err := primitive.CopyInt64Out(t, outOffsetPtr, outOffset); err != nil { return 0, nil, err } } - if n == 0 { - return 0, nil, err + if n != 0 { + // On Linux, inotify behavior is not very consistent with splice(2). We try + // our best to emulate Linux for very basic calls to splice, where for some + // reason, events are generated for output files, but not input files. + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } - // On Linux, inotify behavior is not very consistent with splice(2). We try - // our best to emulate Linux for very basic calls to splice, where for some - // reason, events are generated for output files, but not input files. - outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) - return uintptr(n), nil, nil + // We can only pass a single file to handleIOError, so pick inFile arbitrarily. + // This is used only for debugging purposes. + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile) } // Tee implements Linux syscall tee(2). @@ -219,12 +203,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) outFile := t.GetFileVFS2(outFD) if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) // Check that both files support the required directionality. if !inFile.IsReadable() || !outFile.IsWritable() { @@ -247,45 +231,274 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Copy data. var ( - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { - n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 { - return uintptr(n), nil, nil + n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + if err = dw.waitForBoth(t); err != nil { + break + } + } + + if n != 0 { + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) + + // If a partial write is completed, the error is dropped. Log it here. + if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { + log.Debugf("tee completed a partial write with error: %v", err) + err = nil + } + } + + // We can only pass a single file to handleIOError, so pick inFile arbitrarily. + // This is used only for debugging purposes. + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "tee", inFile) +} + +// Sendfile implements linux system call sendfile(2). +func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + outFD := args[0].Int() + inFD := args[1].Int() + offsetAddr := args[2].Pointer() + count := int64(args[3].SizeT()) + + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef(t) + if !inFile.IsReadable() { + return 0, nil, syserror.EBADF + } + + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef(t) + if !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // Verify that the outFile Append flag is not set. + if outFile.StatusFlags()&linux.O_APPEND != 0 { + return 0, nil, syserror.EINVAL + } + + // Verify that inFile is a regular file or block device. This is a + // requirement; the same check appears in Linux + // (fs/splice.c:splice_direct_to_actor). + if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil { + return 0, nil, err + } else if stat.Mask&linux.STATX_TYPE == 0 || + (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { + return 0, nil, syserror.EINVAL + } + + // Copy offset if it exists. + offset := int64(-1) + if offsetAddr != 0 { + if inFile.Options().DenyPRead { + return 0, nil, syserror.ESPIPE } - if err != syserror.ErrWouldBlock || nonBlock { + var offsetP primitive.Int64 + if _, err := offsetP.CopyIn(t, offsetAddr); err != nil { return 0, nil, err } + offset = int64(offsetP) - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the tee operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if offset+count < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Validate count. This must come after offset checks. + if count < 0 { + return 0, nil, syserror.EINVAL + } + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Copy data. + var ( + n int64 + err error + ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + // Reading from input file should never block, since it is regular or + // block device. We only need to check if writing to the output file + // can block. + nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 + if outIsPipe { + for n < count { + var spliceN int64 + spliceN, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count) + if offset != -1 { + offset += spliceN + } + n += spliceN + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) } - if err := t.Block(inCh); err != nil { - return 0, nil, err + if err != nil { + break } } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. + } else { + // Read inFile to buffer, then write the contents to outFile. + buf := make([]byte, count) + for n < count { + var readN int64 + if offset != -1 { + readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) + offset += readN + } else { + readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + } + n += readN + + // Write all of the bytes that we read. This may need + // multiple write calls to complete. + wbuf := buf[:readN] + for len(wbuf) > 0 { + var writeN int64 + writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) + wbuf = wbuf[writeN:] + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForOut(t) + } + if err != nil { + // We didn't complete the write. Only report the bytes that were actually + // written, and rewind offsets as needed. + notWritten := int64(len(wbuf)) + n -= notWritten + if offset == -1 { + // We modified the offset of the input file itself during the read + // operation. Rewind it. + if _, seekErr := inFile.Seek(t, -notWritten, linux.SEEK_CUR); seekErr != nil { + // Log the error but don't return it, since the write has already + // completed successfully. + log.Warningf("failed to roll back input file offset: %v", seekErr) + } + } else { + // The sendfile call was provided an offset parameter that should be + // adjusted to reflect the number of bytes sent. Rewind it. + offset -= notWritten + } + break + } } - if err := t.Block(outCh); err != nil { - return 0, nil, err + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) + } + if err != nil { + break } } } + + if offsetAddr != 0 { + // Copy out the new offset. + offsetP := primitive.Uint64(offset) + if _, err := offsetP.CopyOut(t, offsetAddr); err != nil { + return 0, nil, err + } + } + + if n != 0 { + inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) + + if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { + // If a partial write is completed, the error is dropped. Log it here. + log.Debugf("sendfile completed a partial write with error: %v", err) + err = nil + } + } + + // We can only pass a single file to handleIOError, so pick inFile arbitrarily. + // This is used only for debugging purposes. + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) +} + +// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not +// thread-safe, and does not take a reference on the vfs.FileDescriptions. +// +// Users must call destroy() when finished. +type dualWaiter struct { + inFile *vfs.FileDescription + outFile *vfs.FileDescription + + inW waiter.Entry + inCh chan struct{} + outW waiter.Entry + outCh chan struct{} +} + +// waitForBoth waits for both dw.inFile and dw.outFile to be ready. +func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { + if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if dw.inCh == nil { + dw.inW, dw.inCh = waiter.NewChannelEntry(nil) + dw.inFile.EventRegister(&dw.inW, eventMaskRead) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.inCh); err != nil { + return err + } + } + return dw.waitForOut(t) +} + +// waitForOut waits for dw.outfile to be read. +func (dw *dualWaiter) waitForOut(t *kernel.Task) error { + if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.outCh); err != nil { + return err + } + } + return nil +} + +// destroy cleans up resources help by dw. No more calls to wait* can occur +// after destroy is called. +func (dw *dualWaiter) destroy() { + if dw.inCh != nil { + dw.inFile.EventUnregister(&dw.inW) + dw.inCh = nil + } + if dw.outCh != nil { + dw.outFile.EventUnregister(&dw.outW) + dw.outCh = nil + } + dw.inFile = nil + dw.outFile = nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index bb1d5cac4..0f5d5189c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -65,7 +65,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags } root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root if !path.Absolute { if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { @@ -73,7 +73,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { @@ -85,7 +85,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags // former may be able to use opened file state to expedite the // Stat. statx, err := dirfile.Stat(t, opts) - dirfile.DecRef() + dirfile.DecRef(t) if err != nil { return err } @@ -96,8 +96,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags } start = dirfile.VirtualDentry() start.IncRef() - defer start.DecRef() - dirfile.DecRef() + defer start.DecRef(t) + dirfile.DecRef(t) } } @@ -132,7 +132,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) statx, err := file.Stat(t, vfs.StatOptions{ Mask: linux.STATX_BASIC_STATS, @@ -177,7 +177,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root if !path.Absolute { if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { @@ -185,7 +185,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { @@ -197,7 +197,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // former may be able to use opened file state to expedite the // Stat. statx, err := dirfile.Stat(t, opts) - dirfile.DecRef() + dirfile.DecRef(t) if err != nil { return 0, nil, err } @@ -207,8 +207,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } start = dirfile.VirtualDentry() start.IncRef() - defer start.DecRef() - dirfile.DecRef() + defer start.DecRef(t) + dirfile.DecRef(t) } } @@ -282,7 +282,7 @@ func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) err if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) // access(2) and faccessat(2) check permissions using real // UID/GID, not effective UID/GID. @@ -328,7 +328,7 @@ func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, siz if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop) if err != nil { @@ -358,7 +358,7 @@ func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) if err != nil { @@ -377,7 +377,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 0d0ebf46a..6e9b599e2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -34,7 +34,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, file.SyncFS(t) } @@ -47,7 +47,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, file.Sync(t) } @@ -77,7 +77,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support // is a full-file sync, i.e. fsync(2). As a result, there are severe @@ -108,7 +108,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 { if err := file.Sync(t); err != nil { - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go index 5ac79bc09..250870c03 100644 --- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go @@ -50,11 +50,11 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.EINVAL } vfsObj := t.Kernel().VFS() - file, err := timerfd.New(vfsObj, clock, fileFlags) + file, err := timerfd.New(t, vfsObj, clock, fileFlags) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.TFD_CLOEXEC != 0, }) @@ -79,7 +79,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) tfd, ok := file.Impl().(*timerfd.TimerFileDescription) if !ok { @@ -87,7 +87,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock()) @@ -97,7 +97,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, oldS := tfd.SetTime(newS) if oldValAddr != 0 { oldVal := ktime.ItimerspecFromSetting(tm, oldS) - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + if _, err := oldVal.CopyOut(t, oldValAddr); err != nil { return 0, nil, err } } @@ -113,7 +113,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) tfd, ok := file.Impl().(*timerfd.TimerFileDescription) if !ok { @@ -122,6 +122,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, s := tfd.GetTime() curVal := ktime.ItimerspecFromSetting(tm, s) - _, err := t.CopyOut(curValAddr, &curVal) + _, err := curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 8f497ecc7..c50fd97eb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -44,7 +44,7 @@ func Override() { s.Table[23] = syscalls.Supported("select", Select) s.Table[32] = syscalls.Supported("dup", Dup) s.Table[33] = syscalls.Supported("dup2", Dup2) - delete(s.Table, 40) // sendfile + s.Table[40] = syscalls.Supported("sendfile", Sendfile) s.Table[41] = syscalls.Supported("socket", Socket) s.Table[42] = syscalls.Supported("connect", Connect) s.Table[43] = syscalls.Supported("accept", Accept) @@ -62,7 +62,7 @@ func Override() { s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt) s.Table[59] = syscalls.Supported("execve", Execve) s.Table[72] = syscalls.Supported("fcntl", Fcntl) - s.Table[73] = syscalls.Supported("fcntl", Flock) + s.Table[73] = syscalls.Supported("flock", Flock) s.Table[74] = syscalls.Supported("fsync", Fsync) s.Table[75] = syscalls.Supported("fdatasync", Fdatasync) s.Table[76] = syscalls.Supported("truncate", Truncate) @@ -93,16 +93,16 @@ func Override() { s.Table[165] = syscalls.Supported("mount", Mount) s.Table[166] = syscalls.Supported("umount2", Umount2) s.Table[187] = syscalls.Supported("readahead", Readahead) - s.Table[188] = syscalls.Supported("setxattr", Setxattr) + s.Table[188] = syscalls.Supported("setxattr", SetXattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr) - s.Table[191] = syscalls.Supported("getxattr", Getxattr) + s.Table[191] = syscalls.Supported("getxattr", GetXattr) s.Table[192] = syscalls.Supported("lgetxattr", Lgetxattr) s.Table[193] = syscalls.Supported("fgetxattr", Fgetxattr) - s.Table[194] = syscalls.Supported("listxattr", Listxattr) + s.Table[194] = syscalls.Supported("listxattr", ListXattr) s.Table[195] = syscalls.Supported("llistxattr", Llistxattr) s.Table[196] = syscalls.Supported("flistxattr", Flistxattr) - s.Table[197] = syscalls.Supported("removexattr", Removexattr) + s.Table[197] = syscalls.Supported("removexattr", RemoveXattr) s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr) s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr) s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) @@ -163,6 +163,112 @@ func Override() { // Override ARM64. s = linux.ARM64 + s.Table[2] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) + s.Table[5] = syscalls.Supported("setxattr", SetXattr) + s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) + s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) + s.Table[8] = syscalls.Supported("getxattr", GetXattr) + s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr) + s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr) + s.Table[11] = syscalls.Supported("listxattr", ListXattr) + s.Table[12] = syscalls.Supported("llistxattr", Llistxattr) + s.Table[13] = syscalls.Supported("flistxattr", Flistxattr) + s.Table[14] = syscalls.Supported("removexattr", RemoveXattr) + s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr) + s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr) + s.Table[17] = syscalls.Supported("getcwd", Getcwd) + s.Table[19] = syscalls.Supported("eventfd2", Eventfd2) + s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1) + s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl) + s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait) + s.Table[23] = syscalls.Supported("dup", Dup) + s.Table[24] = syscalls.Supported("dup3", Dup3) + s.Table[25] = syscalls.Supported("fcntl", Fcntl) + s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil) + s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[29] = syscalls.Supported("ioctl", Ioctl) + s.Table[32] = syscalls.Supported("flock", Flock) + s.Table[33] = syscalls.Supported("mknodat", Mknodat) + s.Table[34] = syscalls.Supported("mkdirat", Mkdirat) + s.Table[35] = syscalls.Supported("unlinkat", Unlinkat) + s.Table[36] = syscalls.Supported("symlinkat", Symlinkat) + s.Table[37] = syscalls.Supported("linkat", Linkat) + s.Table[38] = syscalls.Supported("renameat", Renameat) + s.Table[39] = syscalls.Supported("umount2", Umount2) + s.Table[40] = syscalls.Supported("mount", Mount) + s.Table[43] = syscalls.Supported("statfs", Statfs) + s.Table[44] = syscalls.Supported("fstatfs", Fstatfs) + s.Table[45] = syscalls.Supported("truncate", Truncate) + s.Table[46] = syscalls.Supported("ftruncate", Ftruncate) + s.Table[47] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil) + s.Table[48] = syscalls.Supported("faccessat", Faccessat) + s.Table[49] = syscalls.Supported("chdir", Chdir) + s.Table[50] = syscalls.Supported("fchdir", Fchdir) + s.Table[51] = syscalls.Supported("chroot", Chroot) + s.Table[52] = syscalls.Supported("fchmod", Fchmod) + s.Table[53] = syscalls.Supported("fchmodat", Fchmodat) + s.Table[54] = syscalls.Supported("fchownat", Fchownat) + s.Table[55] = syscalls.Supported("fchown", Fchown) + s.Table[56] = syscalls.Supported("openat", Openat) + s.Table[57] = syscalls.Supported("close", Close) + s.Table[59] = syscalls.Supported("pipe2", Pipe2) + s.Table[61] = syscalls.Supported("getdents64", Getdents64) + s.Table[62] = syscalls.Supported("lseek", Lseek) s.Table[63] = syscalls.Supported("read", Read) + s.Table[64] = syscalls.Supported("write", Write) + s.Table[65] = syscalls.Supported("readv", Readv) + s.Table[66] = syscalls.Supported("writev", Writev) + s.Table[67] = syscalls.Supported("pread64", Pread64) + s.Table[68] = syscalls.Supported("pwrite64", Pwrite64) + s.Table[69] = syscalls.Supported("preadv", Preadv) + s.Table[70] = syscalls.Supported("pwritev", Pwritev) + s.Table[71] = syscalls.Supported("sendfile", Sendfile) + s.Table[72] = syscalls.Supported("pselect", Pselect) + s.Table[73] = syscalls.Supported("ppoll", Ppoll) + s.Table[74] = syscalls.Supported("signalfd4", Signalfd4) + s.Table[76] = syscalls.Supported("splice", Splice) + s.Table[77] = syscalls.Supported("tee", Tee) + s.Table[78] = syscalls.Supported("readlinkat", Readlinkat) + s.Table[79] = syscalls.Supported("newfstatat", Newfstatat) + s.Table[80] = syscalls.Supported("fstat", Fstat) + s.Table[81] = syscalls.Supported("sync", Sync) + s.Table[82] = syscalls.Supported("fsync", Fsync) + s.Table[83] = syscalls.Supported("fdatasync", Fdatasync) + s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange) + s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate) + s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime) + s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime) + s.Table[88] = syscalls.Supported("utimensat", Utimensat) + s.Table[198] = syscalls.Supported("socket", Socket) + s.Table[199] = syscalls.Supported("socketpair", SocketPair) + s.Table[200] = syscalls.Supported("bind", Bind) + s.Table[201] = syscalls.Supported("listen", Listen) + s.Table[202] = syscalls.Supported("accept", Accept) + s.Table[203] = syscalls.Supported("connect", Connect) + s.Table[204] = syscalls.Supported("getsockname", GetSockName) + s.Table[205] = syscalls.Supported("getpeername", GetPeerName) + s.Table[206] = syscalls.Supported("sendto", SendTo) + s.Table[207] = syscalls.Supported("recvfrom", RecvFrom) + s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt) + s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt) + s.Table[210] = syscalls.Supported("shutdown", Shutdown) + s.Table[211] = syscalls.Supported("sendmsg", SendMsg) + s.Table[212] = syscalls.Supported("recvmsg", RecvMsg) + s.Table[213] = syscalls.Supported("readahead", Readahead) + s.Table[221] = syscalls.Supported("execve", Execve) + s.Table[222] = syscalls.Supported("mmap", Mmap) + s.Table[223] = syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil) + s.Table[242] = syscalls.Supported("accept4", Accept4) + s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg) + s.Table[267] = syscalls.Supported("syncfs", Syncfs) + s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg) + s.Table[276] = syscalls.Supported("renameat2", Renameat2) + s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate) + s.Table[281] = syscalls.Supported("execveat", Execveat) + s.Table[286] = syscalls.Supported("preadv2", Preadv2) + s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) + s.Table[291] = syscalls.Supported("statx", Statx) + s.Init() } diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index af455d5c1..e05723ef9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -26,8 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Listxattr implements Linux syscall listxattr(2). -func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// ListXattr implements Linux syscall listxattr(2). +func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return listxattr(t, args, followFinalSymlink) } @@ -49,9 +49,9 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) - names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) + names, err := t.Kernel().VFS().ListXattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) if err != nil { return 0, nil, err } @@ -72,9 +72,9 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) - names, err := file.Listxattr(t, uint64(size)) + names, err := file.ListXattr(t, uint64(size)) if err != nil { return 0, nil, err } @@ -85,8 +85,8 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return uintptr(n), nil, nil } -// Getxattr implements Linux syscall getxattr(2). -func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// GetXattr implements Linux syscall getxattr(2). +func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return getxattr(t, args, followFinalSymlink) } @@ -109,14 +109,14 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) name, err := copyInXattrName(t, nameAddr) if err != nil { return 0, nil, err } - value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{ + value, err := t.Kernel().VFS().GetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetXattrOptions{ Name: name, Size: uint64(size), }) @@ -141,14 +141,14 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) name, err := copyInXattrName(t, nameAddr) if err != nil { return 0, nil, err } - value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)}) + value, err := file.GetXattr(t, &vfs.GetXattrOptions{Name: name, Size: uint64(size)}) if err != nil { return 0, nil, err } @@ -159,8 +159,8 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return uintptr(n), nil, nil } -// Setxattr implements Linux syscall setxattr(2). -func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// SetXattr implements Linux syscall setxattr(2). +func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, setxattr(t, args, followFinalSymlink) } @@ -188,7 +188,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -199,7 +199,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return err } - return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{ + return t.Kernel().VFS().SetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetXattrOptions{ Name: name, Value: value, Flags: uint32(flags), @@ -222,7 +222,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -233,15 +233,15 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{ + return 0, nil, file.SetXattr(t, &vfs.SetXattrOptions{ Name: name, Value: value, Flags: uint32(flags), }) } -// Removexattr implements Linux syscall removexattr(2). -func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// RemoveXattr implements Linux syscall removexattr(2). +func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, removexattr(t, args, followFinalSymlink) } @@ -262,14 +262,14 @@ func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSy if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) name, err := copyInXattrName(t, nameAddr) if err != nil { return err } - return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name) + return t.Kernel().VFS().RemoveXattrAt(t, t.Credentials(), &tpop.pop, name) } // Fremovexattr implements Linux syscall fremovexattr(2). @@ -281,14 +281,14 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) name, err := copyInXattrName(t, nameAddr) if err != nil { return 0, nil, err } - return 0, nil, file.Removexattr(t, name) + return 0, nil, file.RemoveXattr(t, name) } func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go index 65868cb26..cd1b95117 100644 --- a/pkg/sentry/time/parameters.go +++ b/pkg/sentry/time/parameters.go @@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par // // The log level is determined by the error severity. func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) { - fn := log.Debugf - if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() { + magNS := int64(errorNS.Magnitude()) + if magNS <= 10*time.Microsecond.Nanoseconds() { + // Don't log small errors. + return + } + fn := log.Infof + if magNS > time.Millisecond.Nanoseconds() { + // Upgrade large errors to warning. fn = log.Warningf - } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() { - fn = log.Infof } fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 642769e7c..c855608db 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -27,6 +27,39 @@ go_template_instance( }, ) +go_template_instance( + name = "file_description_refs", + out = "file_description_refs.go", + package = "vfs", + prefix = "FileDescription", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "FileDescription", + }, +) + +go_template_instance( + name = "mount_namespace_refs", + out = "mount_namespace_refs.go", + package = "vfs", + prefix = "MountNamespace", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "MountNamespace", + }, +) + +go_template_instance( + name = "filesystem_refs", + out = "filesystem_refs.go", + package = "vfs", + prefix = "Filesystem", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Filesystem", + }, +) + go_library( name = "vfs", srcs = [ @@ -40,12 +73,15 @@ go_library( "event_list.go", "file_description.go", "file_description_impl_util.go", + "file_description_refs.go", "filesystem.go", "filesystem_impl_util.go", + "filesystem_refs.go", "filesystem_type.go", "inotify.go", "lock.go", "mount.go", + "mount_namespace_refs.go", "mount_unsafe.go", "options.go", "pathname.go", @@ -56,13 +92,13 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", "//pkg/fspath", "//pkg/gohacks", "//pkg/log", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 4b9faf2ea..5aad31b78 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -184,12 +184,3 @@ This construction, which is essentially a type-safe analogue to Linux's - File locking - `O_ASYNC` - -- Reference counts in the `vfs` package do not use the `refs` package since - `refs.AtomicRefCount` adds 64 bytes of overhead to each 8-byte reference - count, resulting in considerable cache bloat. 24 bytes of this overhead is - for weak reference support, which have poor performance and will not be used - by VFS2. The remaining 40 bytes is to store a descriptive string and stack - trace for reference leak checking; we can support reference leak checking - without incurring this space overhead by including the applicable - information directly in finalizers for applicable types. diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 641e3e502..bdfd3ca8f 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -52,6 +52,8 @@ const ( ) // anonFilesystemType implements FilesystemType. +// +// +stateify savable type anonFilesystemType struct{} // GetFilesystem implements FilesystemType.GetFilesystem. @@ -69,12 +71,15 @@ func (anonFilesystemType) Name() string { // // Since all Dentries in anonFilesystem are non-directories, all FilesystemImpl // methods that would require an anonDentry to be a directory return ENOTDIR. +// +// +stateify savable type anonFilesystem struct { vfsfs Filesystem devMinor uint32 } +// +stateify savable type anonDentry struct { vfsd Dentry @@ -82,7 +87,7 @@ type anonDentry struct { } // Release implements FilesystemImpl.Release. -func (fs *anonFilesystem) Release() { +func (fs *anonFilesystem) Release(ctx context.Context) { } // Sync implements FilesystemImpl.Sync. @@ -245,32 +250,32 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements FilesystemImpl.ListxattrAt. -func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements FilesystemImpl.ListXattrAt. +func (fs *anonFilesystem) ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { if !rp.Done() { return nil, syserror.ENOTDIR } return nil, nil } -// GetxattrAt implements FilesystemImpl.GetxattrAt. -func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) { +// GetXattrAt implements FilesystemImpl.GetXattrAt. +func (fs *anonFilesystem) GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) { if !rp.Done() { return "", syserror.ENOTDIR } return "", syserror.ENOTSUP } -// SetxattrAt implements FilesystemImpl.SetxattrAt. -func (fs *anonFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { +// SetXattrAt implements FilesystemImpl.SetXattrAt. +func (fs *anonFilesystem) SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error { if !rp.Done() { return syserror.ENOTDIR } return syserror.EPERM } -// RemovexattrAt implements FilesystemImpl.RemovexattrAt. -func (fs *anonFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { +// RemoveXattrAt implements FilesystemImpl.RemoveXattrAt. +func (fs *anonFilesystem) RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error { if !rp.Done() { return syserror.ENOTDIR } @@ -294,7 +299,7 @@ func (d *anonDentry) TryIncRef() bool { } // DecRef implements DentryImpl.DecRef. -func (d *anonDentry) DecRef() { +func (d *anonDentry) DecRef(ctx context.Context) { // no-op } @@ -303,7 +308,7 @@ func (d *anonDentry) DecRef() { // Although Linux technically supports inotify on pseudo filesystems (inotify // is implemented at the vfs layer), it is not particularly useful. It is left // unimplemented until someone actually needs it. -func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {} +func (d *anonDentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) {} // Watches implements DentryImpl.Watches. func (d *anonDentry) Watches() *Watches { @@ -311,4 +316,4 @@ func (d *anonDentry) Watches() *Watches { } // OnZeroWatches implements Dentry.OnZeroWatches. -func (d *anonDentry) OnZeroWatches() {} +func (d *anonDentry) OnZeroWatches(context.Context) {} diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index c9e724fef..97018651f 100644 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go @@ -40,6 +40,30 @@ func MountNamespaceFromContext(ctx context.Context) *MountNamespace { return nil } +type mountNamespaceContext struct { + context.Context + mntns *MountNamespace +} + +// Value implements Context.Value. +func (mc mountNamespaceContext) Value(key interface{}) interface{} { + switch key { + case CtxMountNamespace: + mc.mntns.IncRef() + return mc.mntns + default: + return mc.Context.Value(key) + } +} + +// WithMountNamespace returns a copy of ctx with the given MountNamespace. +func WithMountNamespace(ctx context.Context, mntns *MountNamespace) context.Context { + return &mountNamespaceContext{ + Context: ctx, + mntns: mntns, + } +} + // RootFromContext returns the VFS root used by ctx. It takes a reference on // the returned VirtualDentry. If ctx does not have a specific VFS root, // RootFromContext returns a zero-value VirtualDentry. diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index cea3e6955..320ab7ce1 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -17,6 +17,7 @@ package vfs import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -88,6 +89,8 @@ func (d *Dentry) Impl() DentryImpl { // DentryImpl contains implementation details for a Dentry. Implementations of // DentryImpl should contain their associated Dentry by value as their first // field. +// +// +stateify savable type DentryImpl interface { // IncRef increments the Dentry's reference count. A Dentry with a non-zero // reference count must remain coherent with the state of the filesystem. @@ -102,7 +105,7 @@ type DentryImpl interface { TryIncRef() bool // DecRef decrements the Dentry's reference count. - DecRef() + DecRef(ctx context.Context) // InotifyWithParent notifies all watches on the targets represented by this // dentry and its parent. The parent's watches are notified first, followed @@ -113,7 +116,7 @@ type DentryImpl interface { // // Note that the events may not actually propagate up to the user, depending // on the event masks. - InotifyWithParent(events, cookie uint32, et EventType) + InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) // Watches returns the set of inotify watches for the file corresponding to // the Dentry. Dentries that are hard links to the same underlying file @@ -135,7 +138,7 @@ type DentryImpl interface { // The caller does not need to hold a reference on the dentry. OnZeroWatches // may acquire inotify locks, so to prevent deadlock, no inotify locks should // be held by the caller. - OnZeroWatches() + OnZeroWatches(ctx context.Context) } // IncRef increments d's reference count. @@ -150,8 +153,8 @@ func (d *Dentry) TryIncRef() bool { } // DecRef decrements d's reference count. -func (d *Dentry) DecRef() { - d.impl.DecRef() +func (d *Dentry) DecRef(ctx context.Context) { + d.impl.DecRef(ctx) } // IsDead returns true if d has been deleted or invalidated by its owning @@ -168,8 +171,8 @@ func (d *Dentry) isMounted() bool { // InotifyWithParent notifies all watches on the targets represented by d and // its parent of events. -func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) { - d.impl.InotifyWithParent(events, cookie, et) +func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) { + d.impl.InotifyWithParent(ctx, events, cookie, et) } // Watches returns the set of inotify watches associated with d. @@ -182,8 +185,8 @@ func (d *Dentry) Watches() *Watches { // OnZeroWatches performs cleanup tasks whenever the number of watches on a // dentry drops to zero. -func (d *Dentry) OnZeroWatches() { - d.impl.OnZeroWatches() +func (d *Dentry) OnZeroWatches(ctx context.Context) { + d.impl.OnZeroWatches(ctx) } // The following functions are exported so that filesystem implementations can @@ -214,11 +217,11 @@ func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) { // CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion // succeeds. -func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { +func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) { d.dead = true d.mu.Unlock() if d.isMounted() { - vfs.forgetDeadMountpoint(d) + vfs.forgetDeadMountpoint(ctx, d) } } @@ -226,12 +229,12 @@ func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { // did for reasons outside of VFS' control (e.g. d represents the local state // of a file on a remote filesystem on which the file has already been // deleted). -func (vfs *VirtualFilesystem) InvalidateDentry(d *Dentry) { +func (vfs *VirtualFilesystem) InvalidateDentry(ctx context.Context, d *Dentry) { d.mu.Lock() d.dead = true d.mu.Unlock() if d.isMounted() { - vfs.forgetDeadMountpoint(d) + vfs.forgetDeadMountpoint(ctx, d) } } @@ -241,8 +244,9 @@ func (vfs *VirtualFilesystem) InvalidateDentry(d *Dentry) { // caller must call AbortRenameDentry, CommitRenameReplaceDentry, or // CommitRenameExchangeDentry depending on the rename's outcome. // -// Preconditions: If to is not nil, it must be a child Dentry from the same -// Filesystem. from != to. +// Preconditions: +// * If to is not nil, it must be a child Dentry from the same Filesystem. +// * from != to. func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error { vfs.mountMu.Lock() if mntns.mountpoints[from] != 0 { @@ -278,13 +282,13 @@ func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { // that was replaced by from. // // Preconditions: PrepareRenameDentry was previously called on from and to. -func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, to *Dentry) { +func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) { from.mu.Unlock() if to != nil { to.dead = true to.mu.Unlock() if to.isMounted() { - vfs.forgetDeadMountpoint(to) + vfs.forgetDeadMountpoint(ctx, to) } } } @@ -303,7 +307,7 @@ func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) { // // forgetDeadMountpoint is analogous to Linux's // fs/namespace.c:__detach_mounts(). -func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) { +func (vfs *VirtualFilesystem) forgetDeadMountpoint(ctx context.Context, d *Dentry) { var ( vdsToDecRef []VirtualDentry mountsToDecRef []*Mount @@ -316,9 +320,9 @@ func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) { vfs.mounts.seq.EndWrite() vfs.mountMu.Unlock() for _, vd := range vdsToDecRef { - vd.DecRef() + vd.DecRef(ctx) } for _, mnt := range mountsToDecRef { - mnt.DecRef() + mnt.DecRef(ctx) } } diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go index 1e9dffc8f..dde2ad79b 100644 --- a/pkg/sentry/vfs/device.go +++ b/pkg/sentry/vfs/device.go @@ -22,6 +22,8 @@ import ( ) // DeviceKind indicates whether a device is a block or character device. +// +// +stateify savable type DeviceKind uint32 const ( @@ -44,6 +46,7 @@ func (kind DeviceKind) String() string { } } +// +stateify savable type devTuple struct { kind DeviceKind major uint32 diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 599c3131c..8f36c3e3b 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -27,6 +27,8 @@ import ( var epollCycleMu sync.Mutex // EpollInstance represents an epoll instance, as described by epoll(7). +// +// +stateify savable type EpollInstance struct { vfsfd FileDescription FileDescriptionDefaultImpl @@ -38,11 +40,11 @@ type EpollInstance struct { // interest is the set of file descriptors that are registered with the // EpollInstance for monitoring. interest is protected by interestMu. - interestMu sync.Mutex + interestMu sync.Mutex `state:"nosave"` interest map[epollInterestKey]*epollInterest // mu protects fields in registered epollInterests. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // ready is the set of file descriptors that may be "ready" for I/O. Note // that this must be an ordered list, not a map: "If more than maxevents @@ -55,6 +57,7 @@ type EpollInstance struct { ready epollInterestList } +// +stateify savable type epollInterestKey struct { // file is the registered FileDescription. No reference is held on file; // instead, when the last reference is dropped, FileDescription.DecRef() @@ -67,6 +70,8 @@ type epollInterestKey struct { } // epollInterest represents an EpollInstance's interest in a file descriptor. +// +// +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. epoll *EpollInstance @@ -93,9 +98,9 @@ type epollInterest struct { // NewEpollInstanceFD returns a FileDescription representing a new epoll // instance. A reference is taken on the returned FileDescription. -func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) { +func (vfs *VirtualFilesystem) NewEpollInstanceFD(ctx context.Context) (*FileDescription, error) { vd := vfs.NewAnonVirtualDentry("[eventpoll]") - defer vd.DecRef() + defer vd.DecRef(ctx) ep := &EpollInstance{ interest: make(map[epollInterestKey]*epollInterest), } @@ -110,7 +115,7 @@ func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) { } // Release implements FileDescriptionImpl.Release. -func (ep *EpollInstance) Release() { +func (ep *EpollInstance) Release(ctx context.Context) { // Unregister all polled fds. ep.interestMu.Lock() defer ep.interestMu.Unlock() @@ -186,7 +191,7 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin } // Register interest in file. - mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP + mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP epi := &epollInterest{ epoll: ep, key: key, @@ -257,7 +262,7 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event } // Update epi for the next call to ep.ReadEvents(). - mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP + mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP ep.mu.Lock() epi.mask = mask epi.userData = event.Data @@ -331,11 +336,9 @@ func (ep *EpollInstance) removeLocked(epi *epollInterest) { ep.mu.Unlock() } -// ReadEvents reads up to len(events) ready events into events and returns the -// number of events read. -// -// Preconditions: len(events) != 0. -func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int { +// ReadEvents appends up to maxReady events to events and returns the updated +// slice of events. +func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent, maxEvents int) []linux.EpollEvent { i := 0 // Hot path: avoid defer. ep.mu.Lock() @@ -368,16 +371,16 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int { requeue.PushBack(epi) } // Report ievents. - events[i] = linux.EpollEvent{ + events = append(events, linux.EpollEvent{ Events: ievents.ToLinux(), Data: epi.userData, - } + }) i++ - if i == len(events) { + if i == maxEvents { break } } ep.ready.PushBackList(&requeue) ep.mu.Unlock() - return i + return events } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 0c42574db..1eba0270f 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -37,13 +37,13 @@ import ( // FileDescription methods require that a reference is held. // // FileDescription is analogous to Linux's struct file. +// +// +stateify savable type FileDescription struct { - // refs is the reference count. refs is accessed using atomic memory - // operations. - refs int64 + FileDescriptionRefs // flagsMu protects statusFlags and asyncHandler below. - flagsMu sync.Mutex + flagsMu sync.Mutex `state:"nosave"` // statusFlags contains status flags, "initialized by open(2) and possibly // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic @@ -58,7 +58,7 @@ type FileDescription struct { // epolls is the set of epollInterests registered for this FileDescription. // epolls is protected by epollMu. - epollMu sync.Mutex + epollMu sync.Mutex `state:"nosave"` epolls map[*epollInterest]struct{} // vd is the filesystem location at which this FileDescription was opened. @@ -90,6 +90,8 @@ type FileDescription struct { } // FileDescriptionOptions contains options to FileDescription.Init(). +// +// +stateify savable type FileDescriptionOptions struct { // If AllowDirectIO is true, allow O_DIRECT to be set on the file. AllowDirectIO bool @@ -103,7 +105,7 @@ type FileDescriptionOptions struct { // If UseDentryMetadata is true, calls to FileDescription methods that // interact with file and filesystem metadata (Stat, SetStat, StatFS, - // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling + // ListXattr, GetXattr, SetXattr, RemoveXattr) are implemented by calling // the corresponding FilesystemImpl methods instead of the corresponding // FileDescriptionImpl methods. // @@ -131,7 +133,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou } } - fd.refs = 1 + fd.EnableLeakCheck() // Remove "file creation flags" to mirror the behavior from file.f_flags in // fs/open.c:do_dentry_open. @@ -149,30 +151,9 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou return nil } -// IncRef increments fd's reference count. -func (fd *FileDescription) IncRef() { - atomic.AddInt64(&fd.refs, 1) -} - -// TryIncRef increments fd's reference count and returns true. If fd's -// reference count is already zero, TryIncRef does nothing and returns false. -// -// TryIncRef does not require that a reference is held on fd. -func (fd *FileDescription) TryIncRef() bool { - for { - refs := atomic.LoadInt64(&fd.refs) - if refs <= 0 { - return false - } - if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) { - return true - } - } -} - // DecRef decrements fd's reference count. -func (fd *FileDescription) DecRef() { - if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 { +func (fd *FileDescription) DecRef(ctx context.Context) { + fd.FileDescriptionRefs.DecRef(func() { // Unregister fd from all epoll instances. fd.epollMu.Lock() epolls := fd.epolls @@ -196,11 +177,11 @@ func (fd *FileDescription) DecRef() { } // Release implementation resources. - fd.impl.Release() + fd.impl.Release(ctx) if fd.writable { fd.vd.mount.EndWrite() } - fd.vd.DecRef() + fd.vd.DecRef(ctx) fd.flagsMu.Lock() // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { @@ -208,15 +189,7 @@ func (fd *FileDescription) DecRef() { } fd.asyncHandler = nil fd.flagsMu.Unlock() - } else if refs < 0 { - panic("FileDescription.DecRef() called without holding a reference") - } -} - -// Refs returns the current number of references. The returned count -// is inherently racy and is unsafe to use without external synchronization. -func (fd *FileDescription) Refs() int64 { - return atomic.LoadInt64(&fd.refs) + }) } // Mount returns the mount on which fd was opened. It does not take a reference @@ -289,7 +262,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { return syserror.EINVAL } - // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? + // TODO(gvisor.dev/issue/1035): FileDescriptionImpl.SetOAsync()? const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK fd.flagsMu.Lock() if fd.asyncHandler != nil { @@ -301,7 +274,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede fd.asyncHandler.Unregister(fd) } } - fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags) + atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) fd.flagsMu.Unlock() return nil } @@ -335,7 +308,7 @@ func (fd *FileDescription) Impl() FileDescriptionImpl { type FileDescriptionImpl interface { // Release is called when the associated FileDescription reaches zero // references. - Release() + Release(ctx context.Context) // OnClose is called when a file descriptor representing the // FileDescription is closed. Note that returning a non-nil error does not @@ -354,8 +327,13 @@ type FileDescriptionImpl interface { // represented by the FileDescription. StatFS(ctx context.Context) (linux.Statfs, error) - // Allocate grows file represented by FileDescription to offset + length bytes. + // Allocate grows the file to offset + length bytes. // Only mode == 0 is supported currently. + // + // Allocate should return EISDIR on directories, ESPIPE on pipes, and ENODEV on + // other files where it is not supported. + // + // Preconditions: The FileDescription was opened for writing. Allocate(ctx context.Context, mode, offset, length uint64) error // waiter.Waitable methods may be used to poll for I/O events. @@ -369,8 +347,9 @@ type FileDescriptionImpl interface { // // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP. // - // Preconditions: The FileDescription was opened for reading. - // FileDescriptionOptions.DenyPRead == false. + // Preconditions: + // * The FileDescription was opened for reading. + // * FileDescriptionOptions.DenyPRead == false. PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) // Read is similar to PRead, but does not specify an offset. @@ -401,8 +380,9 @@ type FileDescriptionImpl interface { // - If opts.Flags specifies unsupported options, PWrite returns // EOPNOTSUPP. // - // Preconditions: The FileDescription was opened for writing. - // FileDescriptionOptions.DenyPWrite == false. + // Preconditions: + // * The FileDescription was opened for writing. + // * FileDescriptionOptions.DenyPWrite == false. PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) // Write is similar to PWrite, but does not specify an offset, which is @@ -447,19 +427,19 @@ type FileDescriptionImpl interface { // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // Listxattr returns all extended attribute names for the file. - Listxattr(ctx context.Context, size uint64) ([]string, error) + // ListXattr returns all extended attribute names for the file. + ListXattr(ctx context.Context, size uint64) ([]string, error) - // Getxattr returns the value associated with the given extended attribute + // GetXattr returns the value associated with the given extended attribute // for the file. - Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) + GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) - // Setxattr changes the value associated with the given extended attribute + // SetXattr changes the value associated with the given extended attribute // for the file. - Setxattr(ctx context.Context, opts SetxattrOptions) error + SetXattr(ctx context.Context, opts SetXattrOptions) error - // Removexattr removes the given extended attribute from the file. - Removexattr(ctx context.Context, name string) error + // RemoveXattr removes the given extended attribute from the file. + RemoveXattr(ctx context.Context, name string) error // LockBSD tries to acquire a BSD-style advisory file lock. LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error @@ -475,6 +455,8 @@ type FileDescriptionImpl interface { } // Dirent holds the information contained in struct linux_dirent64. +// +// +stateify savable type Dirent struct { // Name is the filename. Name string @@ -526,7 +508,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St Start: fd.vd, }) stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return stat, err } return fd.impl.Stat(ctx, opts) @@ -541,7 +523,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return err } return fd.impl.SetStat(ctx, opts) @@ -557,12 +539,20 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vd, }) statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return statfs, err } return fd.impl.StatFS(ctx) } +// Allocate grows file represented by FileDescription to offset + length bytes. +func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { + if !fd.IsWritable() { + return syserror.EBADF + } + return fd.impl.Allocate(ctx, mode, offset, length) +} + // Readiness implements waiter.Waitable.Readiness. // // It returns fd's I/O readiness. @@ -654,25 +644,25 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. return fd.impl.Ioctl(ctx, uio, args) } -// Listxattr returns all extended attribute names for the file represented by +// ListXattr returns all extended attribute names for the file represented by // fd. // // If the size of the list (including a NUL terminating byte after every entry) // would exceed size, ERANGE may be returned. Note that implementations // are free to ignore size entirely and return without error). In all cases, // if size is 0, the list should be returned without error, regardless of size. -func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { +func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size) - vfsObj.putResolvingPath(rp) + names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) + vfsObj.putResolvingPath(ctx, rp) return names, err } - names, err := fd.impl.Listxattr(ctx, size) + names, err := fd.impl.ListXattr(ctx, size) if err == syserror.ENOTSUP { // Linux doesn't actually return ENOTSUP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security @@ -683,57 +673,57 @@ func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string return names, err } -// Getxattr returns the value associated with the given extended attribute for +// GetXattr returns the value associated with the given extended attribute for // the file represented by fd. // // If the size of the return value exceeds opts.Size, ERANGE may be returned // (note that implementations are free to ignore opts.Size entirely and return // without error). In all cases, if opts.Size is 0, the value should be // returned without error, regardless of size. -func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) { +func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) (string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(rp) + val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) + vfsObj.putResolvingPath(ctx, rp) return val, err } - return fd.impl.Getxattr(ctx, *opts) + return fd.impl.GetXattr(ctx, *opts) } -// Setxattr changes the value associated with the given extended attribute for +// SetXattr changes the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error { +func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(rp) + err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) + vfsObj.putResolvingPath(ctx, rp) return err } - return fd.impl.Setxattr(ctx, *opts) + return fd.impl.SetXattr(ctx, *opts) } -// Removexattr removes the given extended attribute from the file represented +// RemoveXattr removes the given extended attribute from the file represented // by fd. -func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { +func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name) - vfsObj.putResolvingPath(rp) + err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) + vfsObj.putResolvingPath(ctx, rp) return err } - return fd.impl.Removexattr(ctx, name) + return fd.impl.RemoveXattr(ctx, name) } // SyncFS instructs the filesystem containing fd to execute the semantics of @@ -747,7 +737,7 @@ func (fd *FileDescription) MappedName(ctx context.Context) string { vfsroot := RootFromContext(ctx) s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd) if vfsroot.Ok() { - vfsroot.DecRef() + vfsroot.DecRef(ctx) } return s } @@ -835,3 +825,31 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn } return fd.asyncHandler } + +// FileReadWriteSeeker is a helper struct to pass a FileDescription as +// io.Reader/io.Writer/io.ReadSeeker/etc. +type FileReadWriteSeeker struct { + FD *FileDescription + Ctx context.Context + ROpts ReadOptions + WOpts WriteOptions +} + +// Read implements io.ReadWriteSeeker.Read. +func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { + dst := usermem.BytesIOSequence(p) + ret, err := f.FD.Read(f.Ctx, dst, f.ROpts) + return int(ret), err +} + +// Seek implements io.ReadWriteSeeker.Seek. +func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { + return f.FD.Seek(f.Ctx, offset, int32(whence)) +} + +// Write implements io.ReadWriteSeeker.Write. +func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { + buf := usermem.BytesIOSequence(p) + ret, err := f.FD.Write(f.Ctx, buf, f.WOpts) + return int(ret), err +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 6b8b4ad49..48ca9de44 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -42,6 +42,8 @@ import ( // FileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl to obtain implementations of many FileDescriptionImpl // methods with default behavior analogous to Linux's. +// +// +stateify savable type FileDescriptionDefaultImpl struct{} // OnClose implements FileDescriptionImpl.OnClose analogously to @@ -57,7 +59,11 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err } // Allocate implements FileDescriptionImpl.Allocate analogously to -// fallocate called on regular file, directory or FIFO in Linux. +// fallocate called on an invalid type of file in Linux. +// +// Note that directories can rely on this implementation even though they +// should technically return EISDIR. Allocate should never be called for a +// directory, because it requires a writable fd. func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { return syserror.ENODEV } @@ -134,34 +140,36 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } -// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// ListXattr implements FileDescriptionImpl.ListXattr analogously to // inode_operations::listxattr == NULL in Linux. -func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) { - // This isn't exactly accurate; see FileDescription.Listxattr. +func (FileDescriptionDefaultImpl) ListXattr(ctx context.Context, size uint64) ([]string, error) { + // This isn't exactly accurate; see FileDescription.ListXattr. return nil, syserror.ENOTSUP } -// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// GetXattr implements FileDescriptionImpl.GetXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) { +func (FileDescriptionDefaultImpl) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { return "", syserror.ENOTSUP } -// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// SetXattr implements FileDescriptionImpl.SetXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { +func (FileDescriptionDefaultImpl) SetXattr(ctx context.Context, opts SetXattrOptions) error { return syserror.ENOTSUP } -// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// RemoveXattr implements FileDescriptionImpl.RemoveXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { +func (FileDescriptionDefaultImpl) RemoveXattr(ctx context.Context, name string) error { return syserror.ENOTSUP } // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. +// +// +stateify savable type DirectoryFileDescriptionDefaultImpl struct{} // Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate. @@ -192,6 +200,8 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme // DentryMetadataFileDescriptionImpl may be embedded by implementations of // FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is // true to obtain implementations of Stat and SetStat that panic. +// +// +stateify savable type DentryMetadataFileDescriptionImpl struct{} // Stat implements FileDescriptionImpl.Stat. @@ -206,12 +216,16 @@ func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetSt // DynamicBytesSource represents a data source for a // DynamicBytesFileDescriptionImpl. +// +// +stateify savable type DynamicBytesSource interface { // Generate writes the file's contents to buf. Generate(ctx context.Context, buf *bytes.Buffer) error } // StaticData implements DynamicBytesSource over a static string. +// +// +stateify savable type StaticData struct { Data string } @@ -238,14 +252,24 @@ type WritableDynamicBytesSource interface { // // DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first // use. +// +// +stateify savable type DynamicBytesFileDescriptionImpl struct { data DynamicBytesSource // immutable - mu sync.Mutex // protects the following fields - buf bytes.Buffer + mu sync.Mutex `state:"nosave"` // protects the following fields + buf bytes.Buffer `state:".([]byte)"` off int64 lastRead int64 // offset at which the last Read, PRead, or Seek ended } +func (fd *DynamicBytesFileDescriptionImpl) saveBuf() []byte { + return fd.buf.Bytes() +} + +func (fd *DynamicBytesFileDescriptionImpl) loadBuf(p []byte) { + fd.buf.Write(p) +} + // SetDataSource must be called exactly once on fd before first use. func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) { fd.data = data @@ -378,6 +402,8 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M // LockFD may be used by most implementations of FileDescriptionImpl.Lock* // functions. Caller must call Init(). +// +// +stateify savable type LockFD struct { locks *FileLocks } @@ -405,6 +431,8 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { // NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface // returning ENOLCK. +// +// +stateify savable type NoLockFD struct{} // LockBSD implements vfs.FileDescriptionImpl.LockBSD. diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 3b7e1c273..1cd607c0a 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -80,9 +80,9 @@ type testFD struct { data DynamicBytesSource } -func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription { +func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription { vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef() + defer vd.DecRef(ctx) var fd testFD fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}) fd.DynamicBytesFileDescriptionImpl.SetDataSource(data) @@ -90,7 +90,7 @@ func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesS } // Release implements FileDescriptionImpl.Release. -func (fd *testFD) Release() { +func (fd *testFD) Release(context.Context) { } // SetStatusFlags implements FileDescriptionImpl.SetStatusFlags. @@ -109,11 +109,11 @@ func TestGenCountFD(t *testing.T) { ctx := contexttest.Context(t) vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } - fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{}) - defer fd.DecRef() + fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &genCount{}) + defer fd.DecRef(ctx) // The first read causes Generate to be called to fill the FD's buffer. buf := make([]byte, 2) @@ -167,11 +167,11 @@ func TestWritable(t *testing.T) { ctx := contexttest.Context(t) vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } - fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"}) - defer fd.DecRef() + fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &storeData{data: "init"}) + defer fd.DecRef(ctx) buf := make([]byte, 10) ioseq := usermem.BytesIOSequence(buf) diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 6bb9ca180..c93d94634 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -15,8 +15,6 @@ package vfs import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" @@ -34,9 +32,7 @@ import ( // // +stateify savable type Filesystem struct { - // refs is the reference count. refs is accessed using atomic memory - // operations. - refs int64 + FilesystemRefs // vfs is the VirtualFilesystem that uses this Filesystem. vfs is // immutable. @@ -52,7 +48,7 @@ type Filesystem struct { // Init must be called before first use of fs. func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) { - fs.refs = 1 + fs.EnableLeakCheck() fs.vfs = vfsObj fs.fsType = fsType fs.impl = impl @@ -76,39 +72,14 @@ func (fs *Filesystem) Impl() FilesystemImpl { return fs.impl } -// IncRef increments fs' reference count. -func (fs *Filesystem) IncRef() { - if atomic.AddInt64(&fs.refs, 1) <= 1 { - panic("Filesystem.IncRef() called without holding a reference") - } -} - -// TryIncRef increments fs' reference count and returns true. If fs' reference -// count is zero, TryIncRef does nothing and returns false. -// -// TryIncRef does not require that a reference is held on fs. -func (fs *Filesystem) TryIncRef() bool { - for { - refs := atomic.LoadInt64(&fs.refs) - if refs <= 0 { - return false - } - if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) { - return true - } - } -} - // DecRef decrements fs' reference count. -func (fs *Filesystem) DecRef() { - if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { +func (fs *Filesystem) DecRef(ctx context.Context) { + fs.FilesystemRefs.DecRef(func() { fs.vfs.filesystemsMu.Lock() delete(fs.vfs.filesystems, fs) fs.vfs.filesystemsMu.Unlock() - fs.impl.Release() - } else if refs < 0 { - panic("Filesystem.decRef() called without holding a reference") - } + fs.impl.Release(ctx) + }) } // FilesystemImpl contains implementation details for a Filesystem. @@ -149,7 +120,7 @@ func (fs *Filesystem) DecRef() { type FilesystemImpl interface { // Release is called when the associated Filesystem reaches zero // references. - Release() + Release(ctx context.Context) // Sync "causes all pending modifications to filesystem metadata and cached // file data to be written to the underlying [filesystem]", as by syncfs(2). @@ -212,8 +183,9 @@ type FilesystemImpl interface { // ENOENT. Equivalently, if vd represents a file with a link count of 0 not // created by open(O_TMPFILE) without O_EXCL, LinkAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If LinkAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -231,8 +203,9 @@ type FilesystemImpl interface { // - If the directory in which the new directory would be created has been // removed by RmdirAt or RenameAt, MkdirAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If MkdirAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -253,8 +226,9 @@ type FilesystemImpl interface { // - If the directory in which the file would be created has been removed // by RmdirAt or RenameAt, MknodAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If MknodAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -345,11 +319,12 @@ type FilesystemImpl interface { // - If renaming would replace a non-empty directory, RenameAt returns // ENOTEMPTY. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). oldParentVD.Dentry() was obtained from a - // previous call to - // oldParentVD.Mount().Filesystem().Impl().GetParentDentryAt(). oldName is - // not "." or "..". + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). + // * oldParentVD.Dentry() was obtained from a previous call to + // oldParentVD.Mount().Filesystem().Impl().GetParentDentryAt(). + // * oldName is not "." or "..". // // Postconditions: If RenameAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -372,8 +347,9 @@ type FilesystemImpl interface { // - If the file at rp exists but is not a directory, RmdirAt returns // ENOTDIR. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If RmdirAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -410,8 +386,9 @@ type FilesystemImpl interface { // - If the directory in which the symbolic link would be created has been // removed by RmdirAt or RenameAt, SymlinkAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If SymlinkAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -431,33 +408,34 @@ type FilesystemImpl interface { // // - If the file at rp exists but is a directory, UnlinkAt returns EISDIR. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If UnlinkAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). UnlinkAt(ctx context.Context, rp *ResolvingPath) error - // ListxattrAt returns all extended attribute names for the file at rp. + // ListXattrAt returns all extended attribute names for the file at rp. // // Errors: // // - If extended attributes are not supported by the filesystem, - // ListxattrAt returns ENOTSUP. + // ListXattrAt returns ENOTSUP. // // - If the size of the list (including a NUL terminating byte after every // entry) would exceed size, ERANGE may be returned. Note that // implementations are free to ignore size entirely and return without // error). In all cases, if size is 0, the list should be returned without // error, regardless of size. - ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) + ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) - // GetxattrAt returns the value associated with the given extended + // GetXattrAt returns the value associated with the given extended // attribute for the file at rp. // // Errors: // - // - If extended attributes are not supported by the filesystem, GetxattrAt + // - If extended attributes are not supported by the filesystem, GetXattrAt // returns ENOTSUP. // // - If an extended attribute named opts.Name does not exist, ENODATA is @@ -467,30 +445,30 @@ type FilesystemImpl interface { // returned (note that implementations are free to ignore opts.Size entirely // and return without error). In all cases, if opts.Size is 0, the value // should be returned without error, regardless of size. - GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) + GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) - // SetxattrAt changes the value associated with the given extended + // SetXattrAt changes the value associated with the given extended // attribute for the file at rp. // // Errors: // - // - If extended attributes are not supported by the filesystem, SetxattrAt + // - If extended attributes are not supported by the filesystem, SetXattrAt // returns ENOTSUP. // // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists, // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist, // ENODATA is returned. - SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error - // RemovexattrAt removes the given extended attribute from the file at rp. + // RemoveXattrAt removes the given extended attribute from the file at rp. // // Errors: // // - If extended attributes are not supported by the filesystem, - // RemovexattrAt returns ENOTSUP. + // RemoveXattrAt returns ENOTSUP. // // - If name does not exist, ENODATA is returned. - RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error // BoundEndpointAt returns the Unix socket endpoint bound at the path rp. // @@ -528,6 +506,8 @@ type FilesystemImpl interface { // PrependPathAtVFSRootError is returned by implementations of // FilesystemImpl.PrependPath() when they encounter the contextual VFS root. +// +// +stateify savable type PrependPathAtVFSRootError struct{} // Error implements error.Error. @@ -538,6 +518,8 @@ func (PrependPathAtVFSRootError) Error() string { // PrependPathAtNonMountRootError is returned by implementations of // FilesystemImpl.PrependPath() when they encounter an independent ancestor // Dentry that is not the Mount root. +// +// +stateify savable type PrependPathAtNonMountRootError struct{} // Error implements error.Error. @@ -548,6 +530,8 @@ func (PrependPathAtNonMountRootError) Error() string { // PrependPathSyntheticError is returned by implementations of // FilesystemImpl.PrependPath() for which prepended names do not represent real // paths. +// +// +stateify savable type PrependPathSyntheticError struct{} // Error implements error.Error. diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go index 465e610e0..2620cf975 100644 --- a/pkg/sentry/vfs/filesystem_impl_util.go +++ b/pkg/sentry/vfs/filesystem_impl_util.go @@ -16,6 +16,9 @@ package vfs import ( "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/usermem" ) // GenericParseMountOptions parses a comma-separated list of options of the @@ -41,3 +44,13 @@ func GenericParseMountOptions(str string) map[string]string { } return m } + +// GenericStatFS returns a statfs struct filled with the common fields for a +// general filesystem. This is analogous to Linux's fs/libfs.cs:simple_statfs(). +func GenericStatFS(fsMagic uint64) linux.Statfs { + return linux.Statfs{ + Type: fsMagic, + BlockSize: usermem.PageSize, + NameLength: linux.NAME_MAX, + } +} diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index f2298f7f6..bc19db1d5 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -55,10 +55,13 @@ type registeredFilesystemType struct { // RegisterFilesystemTypeOptions contains options to // VirtualFilesystem.RegisterFilesystem(). +// +// +stateify savable type RegisterFilesystemTypeOptions struct { - // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt() - // for which MountOptions.InternalMount == false to use this filesystem - // type. + // AllowUserMount determines whether users are allowed to mount a file system + // of this type, i.e. through mount(2). If AllowUserMount is true, allow calls + // to VirtualFilesystem.MountAt() for which MountOptions.InternalMount == false + // to use this filesystem type. AllowUserMount bool // If AllowUserList is true, make this filesystem type visible in diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md index e7da49faa..833db213f 100644 --- a/pkg/sentry/vfs/g3doc/inotify.md +++ b/pkg/sentry/vfs/g3doc/inotify.md @@ -28,9 +28,9 @@ The set of all watches held on a single file (i.e., the watch target) is stored in vfs.Watches. Each watch will belong to a different inotify instance (an instance can only have one watch on any watch target). The watches are stored in a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions -to a single file will all share the same vfs.Watches. Activity on the target -causes its vfs.Watches to generate notifications on its watches’ inotify -instances. +to a single file will all share the same vfs.Watches (with the exception of the +gofer filesystem, described in a later section). Activity on the target causes +its vfs.Watches to generate notifications on its watches’ inotify instances. ### vfs.Watch @@ -103,12 +103,12 @@ inotify: unopened p9 file (and possibly an open FID), through which the Sentry interacts with the gofer. * *Solution:* Because there is no inode structure stored in the sandbox, - inotify watches must be held on the dentry. This would be an issue in - the presence of hard links, where multiple dentries would need to share - the same set of watches, but in VFS2, we do not support the internal - creation of hard links on gofer fs. As a result, we make the assumption - that every dentry corresponds to a unique inode. However, the next point - raises an issue with this assumption: + inotify watches must be held on the dentry. For the purposes of inotify, + we assume that every dentry corresponds to a unique inode, which may + cause unexpected behavior in the presence of hard links, where multiple + dentries should share the same set of watches. Indeed, it is impossible + for us to be absolutely sure whether dentries correspond to the same + file or not, due to the following point: * **The Sentry cannot always be aware of hard links on the remote filesystem.** There is no way for us to confirm whether two files on the remote filesystem are actually links to the same inode. QIDs and inodes are diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 8882fa84a..2d27d9d35 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -27,6 +27,8 @@ import ( ) // Dentry is a required type parameter that is a struct with the given fields. +// +// +stateify savable type Dentry struct { // vfsd is the embedded vfs.Dentry corresponding to this vfs.DentryImpl. vfsd vfs.Dentry diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index c2e21ac5f..3f0b8f45b 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -37,6 +37,8 @@ const inotifyEventBaseSize = 16 // // The way events are labelled appears somewhat arbitrary, but they must match // Linux so that IN_EXCL_UNLINK behaves as it does in Linux. +// +// +stateify savable type EventType uint8 // PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and @@ -100,7 +102,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) id := uniqueid.GlobalFromContext(ctx) vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id)) - defer vd.DecRef() + defer vd.DecRef(ctx) fd := &Inotify{ id: id, scratch: make([]byte, inotifyEventBaseSize), @@ -118,7 +120,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) // Release implements FileDescriptionImpl.Release. Release removes all // watches and frees all resources for an inotify instance. -func (i *Inotify) Release() { +func (i *Inotify) Release(ctx context.Context) { var ds []*Dentry // We need to hold i.mu to avoid a race with concurrent calls to @@ -144,7 +146,7 @@ func (i *Inotify) Release() { i.mu.Unlock() for _, d := range ds { - d.OnZeroWatches() + d.OnZeroWatches(ctx) } } @@ -179,12 +181,12 @@ func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask { return mask & ready } -// PRead implements FileDescriptionImpl. +// PRead implements FileDescriptionImpl.PRead. func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.ESPIPE } -// PWrite implements FileDescriptionImpl. +// PWrite implements FileDescriptionImpl.PWrite. func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { return 0, syserror.ESPIPE } @@ -243,7 +245,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt return writeLen, nil } -// Ioctl implements fs.FileOperations.Ioctl. +// Ioctl implements FileDescriptionImpl.Ioctl. func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch args[1].Int() { case linux.FIONREAD: @@ -350,7 +352,7 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) { // RmWatch looks up an inotify watch for the given 'wd' and configures the // target to stop sending events to this inotify instance. -func (i *Inotify) RmWatch(wd int32) error { +func (i *Inotify) RmWatch(ctx context.Context, wd int32) error { i.mu.Lock() // Find the watch we were asked to removed. @@ -374,7 +376,7 @@ func (i *Inotify) RmWatch(wd int32) error { i.mu.Unlock() if remaining == 0 { - w.target.OnZeroWatches() + w.target.OnZeroWatches(ctx) } // Generate the event for the removal. @@ -462,7 +464,7 @@ func (w *Watches) Remove(id uint64) { // Notify queues a new event with watches in this set. Watches with // IN_EXCL_UNLINK are skipped if the event is coming from a child that has been // unlinked. -func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) { +func (w *Watches) Notify(ctx context.Context, name string, events, cookie uint32, et EventType, unlinked bool) { var hasExpired bool w.mu.RLock() for _, watch := range w.ws { @@ -476,13 +478,13 @@ func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlin w.mu.RUnlock() if hasExpired { - w.cleanupExpiredWatches() + w.cleanupExpiredWatches(ctx) } } // This function is relatively expensive and should only be called where there // are expired watches. -func (w *Watches) cleanupExpiredWatches() { +func (w *Watches) cleanupExpiredWatches(ctx context.Context) { // Because of lock ordering, we cannot acquire Inotify.mu for each watch // owner while holding w.mu. As a result, store expired watches locally // before removing. @@ -495,15 +497,15 @@ func (w *Watches) cleanupExpiredWatches() { } w.mu.RUnlock() for _, watch := range toRemove { - watch.owner.RmWatch(watch.wd) + watch.owner.RmWatch(ctx, watch.wd) } } // HandleDeletion is called when the watch target is destroyed. Clear the // watch set, detach watches from the inotify instances they belong to, and // generate the appropriate events. -func (w *Watches) HandleDeletion() { - w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */) +func (w *Watches) HandleDeletion(ctx context.Context) { + w.Notify(ctx, "", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */) // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for // the owner of each watch being deleted. Instead, atomically store the @@ -744,12 +746,12 @@ func InotifyEventFromStatMask(mask uint32) uint32 { // InotifyRemoveChild sends the appriopriate notifications to the watch sets of // the child being removed and its parent. Note that unlike most pairs of // parent/child notifications, the child is notified first in this case. -func InotifyRemoveChild(self, parent *Watches, name string) { +func InotifyRemoveChild(ctx context.Context, self, parent *Watches, name string) { if self != nil { - self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */) + self.Notify(ctx, "", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */) } if parent != nil { - parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */) + parent.Notify(ctx, name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */) } } @@ -762,13 +764,13 @@ func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, } cookie := uniqueid.InotifyCookie(ctx) if oldParent != nil { - oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */) + oldParent.Notify(ctx, oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */) } if newParent != nil { - newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */) + newParent.Notify(ctx, newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */) } // Somewhat surprisingly, self move events do not have a cookie. if renamed != nil { - renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */) + renamed.Notify(ctx, "", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */) } } diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 6c7583a81..55783d4eb 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -33,6 +33,8 @@ import ( // Note that in Linux these two types of locks are _not_ cooperative, because // race and deadlock conditions make merging them prohibitive. We do the same // and keep them oblivious to each other. +// +// +stateify savable type FileLocks struct { // bsd is a set of BSD-style advisory file wide locks, see flock(2). bsd fslock.Locks @@ -46,7 +48,13 @@ func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fsloc if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { return nil } - return syserror.ErrWouldBlock + + // Return an appropriate error for the unsuccessful lock attempt, depending on + // whether this is a blocking or non-blocking operation. + if block == nil { + return syserror.ErrWouldBlock + } + return syserror.ERESTARTSYS } // UnlockBSD releases a BSD-style lock on the entire file. @@ -66,7 +74,13 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fsl if fl.posix.LockRegion(uid, t, rng, block) { return nil } - return syserror.ErrWouldBlock + + // Return an appropriate error for the unsuccessful lock attempt, depending on + // whether this is a blocking or non-blocking operation. + if block == nil { + return syserror.ErrWouldBlock + } + return syserror.ERESTARTSYS } // UnlockPOSIX releases a POSIX-style lock on a file region. diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go index cc1e7d764..638b5d830 100644 --- a/pkg/sentry/vfs/memxattr/xattr.go +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -33,8 +33,8 @@ type SimpleExtendedAttributes struct { xattrs map[string]string } -// Getxattr returns the value at 'name'. -func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) { +// GetXattr returns the value at 'name'. +func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) { x.mu.RLock() value, ok := x.xattrs[opts.Name] x.mu.RUnlock() @@ -49,8 +49,8 @@ func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, return value, nil } -// Setxattr sets 'value' at 'name'. -func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { +// SetXattr sets 'value' at 'name'. +func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { x.mu.Lock() defer x.mu.Unlock() if x.xattrs == nil { @@ -72,8 +72,8 @@ func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { return nil } -// Listxattr returns all names in xattrs. -func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { +// ListXattr returns all names in xattrs. +func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { // Keep track of the size of the buffer needed in listxattr(2) for the list. listSize := 0 x.mu.RLock() @@ -90,8 +90,8 @@ func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { return names, nil } -// Removexattr removes the xattr at 'name'. -func (x *SimpleExtendedAttributes) Removexattr(name string) error { +// RemoveXattr removes the xattr at 'name'. +func (x *SimpleExtendedAttributes) RemoveXattr(name string) error { x.mu.Lock() defer x.mu.Unlock() if _, ok := x.xattrs[name]; !ok { diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 32f901bd8..dfc3ae6c0 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -65,7 +65,7 @@ type Mount struct { // // Invariant: key.parent != nil iff key.point != nil. key.point belongs to // key.parent.fs. - key mountKey + key mountKey `state:".(VirtualDentry)"` // ns is the namespace in which this Mount was mounted. ns is protected by // VirtualFilesystem.mountMu. @@ -114,7 +114,7 @@ func (mnt *Mount) Options() MountOptions { defer mnt.vfs.mountMu.Unlock() return MountOptions{ Flags: mnt.Flags, - ReadOnly: mnt.readOnly(), + ReadOnly: mnt.ReadOnly(), } } @@ -126,16 +126,14 @@ func (mnt *Mount) Options() MountOptions { // // +stateify savable type MountNamespace struct { + MountNamespaceRefs + // Owner is the usernamespace that owns this mount namespace. Owner *auth.UserNamespace // root is the MountNamespace's root mount. root is immutable. root *Mount - // refs is the reference count. refs is accessed using atomic memory - // operations. - refs int64 - // mountpoints maps all Dentries which are mount points in this namespace // to the number of Mounts for which they are mount points. mountpoints is // protected by VirtualFilesystem.mountMu. @@ -154,22 +152,22 @@ type MountNamespace struct { // NewMountNamespace returns a new mount namespace with a root filesystem // configured by the given arguments. A reference is taken on the returned // MountNamespace. -func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { +func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *MountOptions) (*MountNamespace, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { ctx.Warningf("Unknown filesystem type: %s", fsTypeName) return nil, syserror.ENODEV } - fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts) + fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return nil, err } mntns := &MountNamespace{ Owner: creds.UserNamespace, - refs: 1, mountpoints: make(map[*Dentry]uint32), } - mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{}) + mntns.EnableLeakCheck() + mntns.root = newMount(vfs, fs, root, mntns, opts) return mntns, nil } @@ -200,8 +198,8 @@ func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth if err != nil { return nil, err } - defer root.DecRef() - defer fs.DecRef() + defer root.DecRef(ctx) + defer fs.DecRef(ctx) return vfs.NewDisconnectedMount(fs, root, opts) } @@ -221,7 +219,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr if vd.dentry.dead { vd.dentry.mu.Unlock() vfs.mountMu.Unlock() - vd.DecRef() + vd.DecRef(ctx) return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -243,7 +241,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr // This can't fail since we're holding vfs.mountMu. nextmnt.root.IncRef() vd.dentry.mu.Unlock() - vd.DecRef() + vd.DecRef(ctx) vd = VirtualDentry{ mount: nextmnt, dentry: nextmnt.root, @@ -263,16 +261,20 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr } // MountAt creates and mounts a Filesystem configured by the given arguments. -func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { +// The VirtualFilesystem will hold a reference to the Mount until it is unmounted. +// +// This method returns the mounted Mount without a reference, for convenience +// during VFS setup when there is no chance of racing with unmount. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) (*Mount, error) { mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts) if err != nil { - return err + return nil, err } - defer mnt.DecRef() + defer mnt.DecRef(ctx) if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil { - return err + return nil, err } - return nil + return mnt, nil } // UmountAt removes the Mount at the given path. @@ -293,13 +295,13 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti if err != nil { return err } - defer vd.DecRef() + defer vd.DecRef(ctx) if vd.dentry != vd.mount.root { return syserror.EINVAL } vfs.mountMu.Lock() if mntns := MountNamespaceFromContext(ctx); mntns != nil { - defer mntns.DecRef() + defer mntns.DecRef(ctx) if mntns != vd.mount.ns { vfs.mountMu.Unlock() return syserror.EINVAL @@ -335,14 +337,15 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti vfs.mounts.seq.EndWrite() vfs.mountMu.Unlock() for _, vd := range vdsToDecRef { - vd.DecRef() + vd.DecRef(ctx) } for _, mnt := range mountsToDecRef { - mnt.DecRef() + mnt.DecRef(ctx) } return nil } +// +stateify savable type umountRecursiveOptions struct { // If eager is true, ensure that future calls to Mount.tryIncMountedRef() // on umounted mounts fail. @@ -369,8 +372,9 @@ type umountRecursiveOptions struct { // // umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree(). // -// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. +// Preconditions: +// * vfs.mountMu must be locked. +// * vfs.mounts.seq must be in a writer critical section. func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) { if !mnt.umounted { mnt.umounted = true @@ -399,9 +403,11 @@ func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecu // connectLocked makes vd the mount parent/point for mnt. It consumes // references held by vd. // -// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt -// must not already be connected. +// Preconditions: +// * vfs.mountMu must be locked. +// * vfs.mounts.seq must be in a writer critical section. +// * d.mu must be locked. +// * mnt.parent() == nil, i.e. mnt must not already be connected. func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { if checkInvariants { if mnt.parent() != nil { @@ -409,7 +415,7 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns } } mnt.IncRef() // dropped by callers of umountRecursiveLocked - mnt.storeKey(vd) + mnt.setKey(vd) if vd.mount.children == nil { vd.mount.children = make(map[*Mount]struct{}) } @@ -429,16 +435,18 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns // disconnectLocked makes vd have no mount parent/point and returns its old // mount parent/point with a reference held. // -// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. mnt.parent() != nil. +// Preconditions: +// * vfs.mountMu must be locked. +// * vfs.mounts.seq must be in a writer critical section. +// * mnt.parent() != nil. func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { - vd := mnt.loadKey() + vd := mnt.getKey() if checkInvariants { if vd.mount != nil { panic("VFS.disconnectLocked called on disconnected mount") } } - mnt.storeKey(VirtualDentry{}) + mnt.loadKey(VirtualDentry{}) delete(vd.mount.children, mnt) atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 mnt.ns.mountpoints[vd.dentry]-- @@ -479,7 +487,7 @@ func (mnt *Mount) IncRef() { } // DecRef decrements mnt's reference count. -func (mnt *Mount) DecRef() { +func (mnt *Mount) DecRef(ctx context.Context) { refs := atomic.AddInt64(&mnt.refs, -1) if refs&^math.MinInt64 == 0 { // mask out MSB var vd VirtualDentry @@ -490,25 +498,18 @@ func (mnt *Mount) DecRef() { mnt.vfs.mounts.seq.EndWrite() mnt.vfs.mountMu.Unlock() } - mnt.root.DecRef() - mnt.fs.DecRef() + mnt.root.DecRef(ctx) + mnt.fs.DecRef(ctx) if vd.Ok() { - vd.DecRef() + vd.DecRef(ctx) } } } -// IncRef increments mntns' reference count. -func (mntns *MountNamespace) IncRef() { - if atomic.AddInt64(&mntns.refs, 1) <= 1 { - panic("MountNamespace.IncRef() called without holding a reference") - } -} - // DecRef decrements mntns' reference count. -func (mntns *MountNamespace) DecRef() { +func (mntns *MountNamespace) DecRef(ctx context.Context) { vfs := mntns.root.fs.VirtualFilesystem() - if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 { + mntns.MountNamespaceRefs.DecRef(func() { vfs.mountMu.Lock() vfs.mounts.seq.BeginWrite() vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{ @@ -517,14 +518,12 @@ func (mntns *MountNamespace) DecRef() { vfs.mounts.seq.EndWrite() vfs.mountMu.Unlock() for _, vd := range vdsToDecRef { - vd.DecRef() + vd.DecRef(ctx) } for _, mnt := range mountsToDecRef { - mnt.DecRef() + mnt.DecRef(ctx) } - } else if refs < 0 { - panic("MountNamespace.DecRef() called without holding a reference") - } + }) } // getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes @@ -534,7 +533,7 @@ func (mntns *MountNamespace) DecRef() { // getMountAt is analogous to Linux's fs/namei.c:follow_mount(). // // Preconditions: References are held on mnt and d. -func (vfs *VirtualFilesystem) getMountAt(mnt *Mount, d *Dentry) *Mount { +func (vfs *VirtualFilesystem) getMountAt(ctx context.Context, mnt *Mount, d *Dentry) *Mount { // The first mount is special-cased: // // - The caller is assumed to have checked d.isMounted() already. (This @@ -565,7 +564,7 @@ retryFirst: // Raced with umount. continue } - mnt.DecRef() + mnt.DecRef(ctx) mnt = next d = next.root } @@ -576,9 +575,10 @@ retryFirst: // mnt. It takes a reference on the returned VirtualDentry. If no such mount // point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil). // -// Preconditions: References are held on mnt and root. vfsroot is not (mnt, -// mnt.root). -func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry { +// Preconditions: +// * References are held on mnt and root. +// * vfsroot is not (mnt, mnt.root). +func (vfs *VirtualFilesystem) getMountpointAt(ctx context.Context, mnt *Mount, vfsroot VirtualDentry) VirtualDentry { // The first mount is special-cased: // // - The caller must have already checked mnt against vfsroot. @@ -602,12 +602,12 @@ retryFirst: if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can only // happen due to a racing change to Mount.key. - parent.DecRef() + parent.DecRef(ctx) goto retryFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.DecRef() - parent.DecRef() + point.DecRef(ctx) + parent.DecRef(ctx) goto retryFirst } mnt = parent @@ -635,22 +635,29 @@ retryFirst: if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can // only happen due to a racing change to Mount.key. - parent.DecRef() + parent.DecRef(ctx) goto retryNotFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.DecRef() - parent.DecRef() + point.DecRef(ctx) + parent.DecRef(ctx) goto retryNotFirst } - d.DecRef() - mnt.DecRef() + d.DecRef(ctx) + mnt.DecRef(ctx) mnt = parent d = point } return VirtualDentry{mnt, d} } +// SetMountReadOnly sets the mount as ReadOnly. +func (vfs *VirtualFilesystem) SetMountReadOnly(mnt *Mount, ro bool) error { + vfs.mountMu.Lock() + defer vfs.mountMu.Unlock() + return mnt.setReadOnlyLocked(ro) +} + // CheckBeginWrite increments the counter of in-progress write operations on // mnt. If mnt is mounted MS_RDONLY, CheckBeginWrite does nothing and returns // EROFS. @@ -688,7 +695,8 @@ func (mnt *Mount) setReadOnlyLocked(ro bool) error { return nil } -func (mnt *Mount) readOnly() bool { +// ReadOnly returns true if mount is readonly. +func (mnt *Mount) ReadOnly() bool { return atomic.LoadInt64(&mnt.writers) < 0 } @@ -731,11 +739,23 @@ func (mntns *MountNamespace) Root() VirtualDentry { // // Preconditions: taskRootDir.Ok(). func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) { - vfs.mountMu.Lock() - defer vfs.mountMu.Unlock() rootMnt := taskRootDir.mount + + vfs.mountMu.Lock() mounts := rootMnt.submountsLocked() + // Take a reference on mounts since we need to drop vfs.mountMu before + // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()). + for _, mnt := range mounts { + mnt.IncRef() + } + vfs.mountMu.Unlock() + defer func() { + for _, mnt := range mounts { + mnt.DecRef(ctx) + } + }() sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID }) + for _, mnt := range mounts { // Get the path to this mount relative to task root. mntRootVD := VirtualDentry{ @@ -746,7 +766,7 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if err != nil { // For some reason we didn't get a path. Log a warning // and run with empty path. - ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err) + ctx.Warningf("VFS.GenerateProcMounts: error getting pathname for mount root %+v: %v", mnt.root, err) path = "" } if path == "" { @@ -756,7 +776,7 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi } opts := "rw" - if mnt.readOnly() { + if mnt.ReadOnly() { opts = "ro" } if mnt.Flags.NoATime { @@ -780,11 +800,25 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi // // Preconditions: taskRootDir.Ok(). func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) { - vfs.mountMu.Lock() - defer vfs.mountMu.Unlock() rootMnt := taskRootDir.mount + + vfs.mountMu.Lock() mounts := rootMnt.submountsLocked() + // Take a reference on mounts since we need to drop vfs.mountMu before + // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()) or + // vfs.StatAt() (=> FilesystemImpl.StatAt()). + for _, mnt := range mounts { + mnt.IncRef() + } + vfs.mountMu.Unlock() + defer func() { + for _, mnt := range mounts { + mnt.DecRef(ctx) + } + }() sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID }) + + creds := auth.CredentialsFromContext(ctx) for _, mnt := range mounts { // Get the path to this mount relative to task root. mntRootVD := VirtualDentry{ @@ -795,7 +829,7 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo if err != nil { // For some reason we didn't get a path. Log a warning // and run with empty path. - ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err) + ctx.Warningf("VFS.GenerateProcMountInfo: error getting pathname for mount root %+v: %v", mnt.root, err) path = "" } if path == "" { @@ -808,9 +842,10 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo Root: mntRootVD, Start: mntRootVD, } - statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{}) + statx, err := vfs.StatAt(ctx, creds, pop, &StatOptions{}) if err != nil { // Well that's not good. Ignore this mount. + ctx.Warningf("VFS.GenerateProcMountInfo: failed to stat mount root %+v: %v", mnt.root, err) break } @@ -822,6 +857,9 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo fmt.Fprintf(buf, "%d ", mnt.ID) // (2) Parent ID (or this ID if there is no parent). + // Note that even if the call to mnt.parent() races with Mount + // destruction (which is possible since we're not holding vfs.mountMu), + // its Mount.ID will still be valid. pID := mnt.ID if p := mnt.parent(); p != nil { pID = p.ID @@ -844,7 +882,7 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo // (6) Mount options. opts := "rw" - if mnt.readOnly() { + if mnt.ReadOnly() { opts = "ro" } if mnt.Flags.NoATime { @@ -883,7 +921,7 @@ func superBlockOpts(mountPath string, mnt *Mount) string { // gVisor doesn't (yet) have a concept of super block options, so we // use the ro/rw bit from the mount flag. opts := "rw" - if mnt.readOnly() { + if mnt.ReadOnly() { opts = "ro" } diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index 3335e4057..cb8c56bd3 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -38,7 +38,7 @@ func TestMountTableInsertLookup(t *testing.T) { mt.Init() mount := &Mount{} - mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) + mount.setKey(VirtualDentry{&Mount{}, &Dentry{}}) mt.Insert(mount) if m := mt.Lookup(mount.parent(), mount.point()); m != mount { @@ -79,7 +79,7 @@ const enableComparativeBenchmarks = false func newBenchMount() *Mount { mount := &Mount{} - mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) + mount.loadKey(VirtualDentry{&Mount{}, &Dentry{}}) return mount } @@ -94,7 +94,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { for i := 0; i < numMounts; i++ { mount := newBenchMount() mt.Insert(mount) - keys = append(keys, mount.loadKey()) + keys = append(keys, mount.saveKey()) } var ready sync.WaitGroup @@ -146,7 +146,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := mount.loadKey() + key := mount.saveKey() ms[key] = mount keys = append(keys, key) } @@ -201,7 +201,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := mount.loadKey() + key := mount.getKey() ms.Store(key, mount) keys = append(keys, key) } @@ -283,7 +283,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms[mount.loadKey()] = mount + ms[mount.getKey()] = mount } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -318,7 +318,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { var ms sync.Map for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms.Store(mount.loadKey(), mount) + ms.Store(mount.saveKey(), mount) } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -372,7 +372,7 @@ func BenchmarkMountMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms[mount.loadKey()] = mount + ms[mount.saveKey()] = mount } } @@ -392,7 +392,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Store(mount.loadKey(), mount) + ms.Store(mount.saveKey(), mount) } } @@ -425,13 +425,13 @@ func BenchmarkMountMapRemove(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := range mounts { mount := mounts[i] - ms[mount.loadKey()] = mount + ms[mount.saveKey()] = mount } b.ResetTimer() for i := range mounts { mount := mounts[i] - delete(ms, mount.loadKey()) + delete(ms, mount.saveKey()) } } @@ -447,12 +447,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) { var ms sync.Map for i := range mounts { mount := mounts[i] - ms.Store(mount.loadKey(), mount) + ms.Store(mount.saveKey(), mount) } b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Delete(mount.loadKey()) + ms.Delete(mount.saveKey()) } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index 70f850ca4..b7d122d22 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. @@ -34,6 +34,8 @@ import ( // structurally identical to VirtualDentry, but stores its fields as // unsafe.Pointer since mutators synchronize with VFS path traversal using // seqcounts. +// +// This is explicitly not savable. type mountKey struct { parent unsafe.Pointer // *Mount point unsafe.Pointer // *Dentry @@ -47,19 +49,23 @@ func (mnt *Mount) point() *Dentry { return (*Dentry)(atomic.LoadPointer(&mnt.key.point)) } -func (mnt *Mount) loadKey() VirtualDentry { +func (mnt *Mount) getKey() VirtualDentry { return VirtualDentry{ mount: mnt.parent(), dentry: mnt.point(), } } +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // Invariant: mnt.key.parent == nil. vd.Ok(). -func (mnt *Mount) storeKey(vd VirtualDentry) { +func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). @@ -92,6 +98,7 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 + // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } @@ -217,8 +224,9 @@ func (mt *mountTable) Insert(mount *Mount) { // insertSeqed inserts the given mount into mt. // -// Preconditions: mt.seq must be in a writer critical section. mt must not -// already contain a Mount with the same mount point and parent. +// Preconditions: +// * mt.seq must be in a writer critical section. +// * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) @@ -269,9 +277,11 @@ func (mt *mountTable) insertSeqed(mount *Mount) { atomic.StorePointer(&mt.slots, newSlots) } -// Preconditions: There are no concurrent mutators of the table (slots, cap). -// If the table is visible to readers, then mt.seq must be in a writer critical -// section. cap must be a power of 2. +// Preconditions: +// * There are no concurrent mutators of the table (slots, cap). +// * If the table is visible to readers, then mt.seq must be in a writer +// critical section. +// * cap must be a power of 2. func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, hash uintptr) { mask := cap - 1 off := (hash & mask) * mountSlotBytes @@ -313,8 +323,9 @@ func (mt *mountTable) Remove(mount *Mount) { // removeSeqed removes the given mount from mt. // -// Preconditions: mt.seq must be in a writer critical section. mt must contain -// mount. +// Preconditions: +// * mt.seq must be in a writer critical section. +// * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) tcap := uintptr(1) << (mt.size & mtSizeOrderMask) diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index f223aeda8..bc79e5ecc 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -21,6 +21,8 @@ import ( // GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and // FilesystemImpl.GetDentryAt(). +// +// +stateify savable type GetDentryOptions struct { // If CheckSearchable is true, FilesystemImpl.GetDentryAt() must check that // the returned Dentry is a directory for which creds has search @@ -30,6 +32,8 @@ type GetDentryOptions struct { // MkdirOptions contains options to VirtualFilesystem.MkdirAt() and // FilesystemImpl.MkdirAt(). +// +// +stateify savable type MkdirOptions struct { // Mode is the file mode bits for the created directory. Mode linux.FileMode @@ -56,6 +60,8 @@ type MkdirOptions struct { // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). +// +// +stateify savable type MknodOptions struct { // Mode is the file type and mode bits for the created file. Mode linux.FileMode @@ -72,6 +78,8 @@ type MknodOptions struct { // MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC. // MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers. +// +// +stateify savable type MountFlags struct { // NoExec is equivalent to MS_NOEXEC. NoExec bool @@ -79,9 +87,22 @@ type MountFlags struct { // NoATime is equivalent to MS_NOATIME and indicates that the // filesystem should not update access time in-place. NoATime bool + + // NoDev is equivalent to MS_NODEV and indicates that the + // filesystem should not allow access to devices (special files). + // TODO(gVisor.dev/issue/3186): respect this flag in non FUSE + // filesystems. + NoDev bool + + // NoSUID is equivalent to MS_NOSUID and indicates that the + // filesystem should not honor set-user-ID and set-group-ID bits or + // file capabilities when executing programs. + NoSUID bool } // MountOptions contains options to VirtualFilesystem.MountAt(). +// +// +stateify savable type MountOptions struct { // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC. Flags MountFlags @@ -92,13 +113,17 @@ type MountOptions struct { // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). GetFilesystemOptions GetFilesystemOptions - // If InternalMount is true, allow the use of filesystem types for which - // RegisterFilesystemTypeOptions.AllowUserMount == false. + // InternalMount indicates whether the mount operation is coming from the + // application, i.e. through mount(2). If InternalMount is true, allow the use + // of filesystem types for which RegisterFilesystemTypeOptions.AllowUserMount + // == false. InternalMount bool } // OpenOptions contains options to VirtualFilesystem.OpenAt() and // FilesystemImpl.OpenAt(). +// +// +stateify savable type OpenOptions struct { // Flags contains access mode and flags as specified for open(2). // @@ -124,6 +149,8 @@ type OpenOptions struct { // ReadOptions contains options to FileDescription.PRead(), // FileDescriptionImpl.PRead(), FileDescription.Read(), and // FileDescriptionImpl.Read(). +// +// +stateify savable type ReadOptions struct { // Flags contains flags as specified for preadv2(2). Flags uint32 @@ -131,6 +158,8 @@ type ReadOptions struct { // RenameOptions contains options to VirtualFilesystem.RenameAt() and // FilesystemImpl.RenameAt(). +// +// +stateify savable type RenameOptions struct { // Flags contains flags as specified for renameat2(2). Flags uint32 @@ -142,6 +171,8 @@ type RenameOptions struct { // SetStatOptions contains options to VirtualFilesystem.SetStatAt(), // FilesystemImpl.SetStatAt(), FileDescription.SetStat(), and // FileDescriptionImpl.SetStat(). +// +// +stateify savable type SetStatOptions struct { // Stat is the metadata that should be set. Only fields indicated by // Stat.Mask should be set. @@ -153,10 +184,18 @@ type SetStatOptions struct { // == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask // instead). Stat linux.Statx + + // NeedWritePerm indicates that write permission on the file is needed for + // this operation. This is needed for truncate(2) (note that ftruncate(2) + // does not require the same check--instead, it checks that the fd is + // writable). + NeedWritePerm bool } // BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt() // and FilesystemImpl.BoundEndpointAt(). +// +// +stateify savable type BoundEndpointOptions struct { // Addr is the path of the file whose socket endpoint is being retrieved. // It is generally irrelevant: most endpoints are stored at a dentry that @@ -173,10 +212,12 @@ type BoundEndpointOptions struct { Addr string } -// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(), -// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and -// FileDescriptionImpl.Getxattr(). -type GetxattrOptions struct { +// GetXattrOptions contains options to VirtualFilesystem.GetXattrAt(), +// FilesystemImpl.GetXattrAt(), FileDescription.GetXattr(), and +// FileDescriptionImpl.GetXattr(). +// +// +stateify savable +type GetXattrOptions struct { // Name is the name of the extended attribute to retrieve. Name string @@ -187,10 +228,12 @@ type GetxattrOptions struct { Size uint64 } -// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), -// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and -// FileDescriptionImpl.Setxattr(). -type SetxattrOptions struct { +// SetXattrOptions contains options to VirtualFilesystem.SetXattrAt(), +// FilesystemImpl.SetXattrAt(), FileDescription.SetXattr(), and +// FileDescriptionImpl.SetXattr(). +// +// +stateify savable +type SetXattrOptions struct { // Name is the name of the extended attribute being mutated. Name string @@ -204,6 +247,8 @@ type SetxattrOptions struct { // StatOptions contains options to VirtualFilesystem.StatAt(), // FilesystemImpl.StatAt(), FileDescription.Stat(), and // FileDescriptionImpl.Stat(). +// +// +stateify savable type StatOptions struct { // Mask is the set of fields in the returned Statx that the FilesystemImpl // or FileDescriptionImpl should provide. Bits are as in linux.Statx.Mask. @@ -221,6 +266,8 @@ type StatOptions struct { } // UmountOptions contains options to VirtualFilesystem.UmountAt(). +// +// +stateify savable type UmountOptions struct { // Flags contains flags as specified for umount2(2). Flags uint32 @@ -229,6 +276,8 @@ type UmountOptions struct { // WriteOptions contains options to FileDescription.PWrite(), // FileDescriptionImpl.PWrite(), FileDescription.Write(), and // FileDescriptionImpl.Write(). +// +// +stateify savable type WriteOptions struct { // Flags contains flags as specified for pwritev2(2). Flags uint32 diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index cd78d66bc..e4da15009 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -47,7 +47,7 @@ func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, haveRef := false defer func() { if haveRef { - vd.DecRef() + vd.DecRef(ctx) } }() @@ -64,12 +64,12 @@ loop: // of FilesystemImpl.PrependPath() may return nil instead. break loop } - nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot) if !nextVD.Ok() { break loop } if haveRef { - vd.DecRef() + vd.DecRef(ctx) } vd = nextVD haveRef = true @@ -101,7 +101,7 @@ func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd haveRef := false defer func() { if haveRef { - vd.DecRef() + vd.DecRef(ctx) } }() loop: @@ -112,12 +112,12 @@ loop: if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { break loop } - nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot) if !nextVD.Ok() { return "", nil } if haveRef { - vd.DecRef() + vd.DecRef(ctx) } vd = nextVD haveRef = true @@ -145,7 +145,7 @@ func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd haveRef := false defer func() { if haveRef { - vd.DecRef() + vd.DecRef(ctx) } }() unreachable := false @@ -157,13 +157,13 @@ loop: if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { break loop } - nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot) if !nextVD.Ok() { unreachable = true break loop } if haveRef { - vd.DecRef() + vd.DecRef(ctx) } vd = nextVD haveRef = true diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 9cb050597..d48520d58 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -16,6 +16,7 @@ package vfs import ( "math" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -25,6 +26,8 @@ import ( ) // AccessTypes is a bitmask of Unix file permissions. +// +// +stateify savable type AccessTypes uint16 // Bits in AccessTypes. @@ -183,7 +186,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { // CheckSetStat checks that creds has permission to change the metadata of a // file with the given permissions, UID, and GID as specified by stat, subject // to the rules of Linux's fs/attr.c:setattr_prepare(). -func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { +func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { + stat := &opts.Stat if stat.Mask&linux.STATX_SIZE != 0 { limit, err := CheckLimit(ctx, 0, int64(stat.Size)) if err != nil { @@ -215,6 +219,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat return syserror.EPERM } } + if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) { + if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil { + return err + } + } if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { if !CanActAsOwner(creds, kuid) { if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || @@ -265,7 +274,7 @@ func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth // operation must not proceed. Otherwise it returns the max length allowed to // without violating the limit. func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { - fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur + fileSizeLimit := limits.FromContextOrDie(ctx).Get(limits.FileSize).Cur if fileSizeLimit > math.MaxInt64 { return size, nil } @@ -278,3 +287,40 @@ func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { } return size, nil } + +// CheckXattrPermissions checks permissions for extended attribute access. +// This is analogous to fs/xattr.c:xattr_permission(). Some key differences: +// * Does not check for read-only filesystem property. +// * Does not check inode immutability or append only mode. In both cases EPERM +// must be returned by filesystem implementations. +// * Does not do inode permission checks. Filesystem implementations should +// handle inode permission checks as they may differ across implementations. +func CheckXattrPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, name string) error { + switch { + case strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX): + // The trusted.* namespace can only be accessed by privileged + // users. + if creds.HasCapability(linux.CAP_SYS_ADMIN) { + return nil + } + if ats.MayWrite() { + return syserror.EPERM + } + return syserror.ENODATA + case strings.HasPrefix(name, linux.XATTR_USER_PREFIX): + // In the user.* namespace, only regular files and directories can have + // extended attributes. For sticky directories, only the owner and + // privileged users can write attributes. + filetype := mode.FileType() + if filetype != linux.ModeRegular && filetype != linux.ModeDirectory { + if ats.MayWrite() { + return syserror.EPERM + } + return syserror.ENODATA + } + if filetype == linux.ModeDirectory && mode&linux.ModeSticky != 0 && ats.MayWrite() && !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + } + return nil +} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 9d047ff88..e4fd55012 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" @@ -34,6 +35,8 @@ import ( // FilesystemImpl methods. // // ResolvingPath is loosely analogous to Linux's struct nameidata. +// +// +stateify savable type ResolvingPath struct { vfs *VirtualFilesystem root VirtualDentry // refs borrowed from PathOperation @@ -87,6 +90,7 @@ func init() { // so error "constants" are really mutable vars, necessitating somewhat // expensive interface object comparisons. +// +stateify savable type resolveMountRootOrJumpError struct{} // Error implements error.Error. @@ -94,6 +98,7 @@ func (resolveMountRootOrJumpError) Error() string { return "resolving mount root or jump" } +// +stateify savable type resolveMountPointError struct{} // Error implements error.Error. @@ -101,6 +106,7 @@ func (resolveMountPointError) Error() string { return "resolving mount point" } +// +stateify savable type resolveAbsSymlinkError struct{} // Error implements error.Error. @@ -136,31 +142,31 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat return rp } -func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) { +func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) { rp.root = VirtualDentry{} - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = nil rp.start = nil - rp.releaseErrorState() + rp.releaseErrorState(ctx) resolvingPathPool.Put(rp) } -func (rp *ResolvingPath) decRefStartAndMount() { +func (rp *ResolvingPath) decRefStartAndMount(ctx context.Context) { if rp.flags&rpflagsHaveStartRef != 0 { - rp.start.DecRef() + rp.start.DecRef(ctx) } if rp.flags&rpflagsHaveMountRef != 0 { - rp.mount.DecRef() + rp.mount.DecRef(ctx) } } -func (rp *ResolvingPath) releaseErrorState() { +func (rp *ResolvingPath) releaseErrorState(ctx context.Context) { if rp.nextStart != nil { - rp.nextStart.DecRef() + rp.nextStart.DecRef(ctx) rp.nextStart = nil } if rp.nextMount != nil { - rp.nextMount.DecRef() + rp.nextMount.DecRef(ctx) rp.nextMount = nil } } @@ -236,13 +242,13 @@ func (rp *ResolvingPath) Advance() { // Restart resets the stream of path components represented by rp to its state // on entry to the current FilesystemImpl method. -func (rp *ResolvingPath) Restart() { +func (rp *ResolvingPath) Restart(ctx context.Context) { rp.pit = rp.origParts[rp.numOrigParts-1] rp.mustBeDir = rp.mustBeDirOrig rp.symlinks = rp.symlinksOrig rp.curPart = rp.numOrigParts - 1 copy(rp.parts[:], rp.origParts[:rp.numOrigParts]) - rp.releaseErrorState() + rp.releaseErrorState(ctx) } func (rp *ResolvingPath) relpathCommit() { @@ -260,13 +266,13 @@ func (rp *ResolvingPath) relpathCommit() { // Mount, CheckRoot returns (unspecified, non-nil error). Otherwise, path // resolution should resolve d's parent normally, and CheckRoot returns (false, // nil). -func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) { +func (rp *ResolvingPath) CheckRoot(ctx context.Context, d *Dentry) (bool, error) { if d == rp.root.dentry && rp.mount == rp.root.mount { // At contextual VFS root (due to e.g. chroot(2)). return true, nil } else if d == rp.mount.root { // At mount root ... - vd := rp.vfs.getMountpointAt(rp.mount, rp.root) + vd := rp.vfs.getMountpointAt(ctx, rp.mount, rp.root) if vd.Ok() { // ... of non-root mount. rp.nextMount = vd.mount @@ -283,11 +289,11 @@ func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) { // to d. If d is a mount point, such that path resolution should switch to // another Mount, CheckMount returns a non-nil error. Otherwise, CheckMount // returns nil. -func (rp *ResolvingPath) CheckMount(d *Dentry) error { +func (rp *ResolvingPath) CheckMount(ctx context.Context, d *Dentry) error { if !d.isMounted() { return nil } - if mnt := rp.vfs.getMountAt(rp.mount, d); mnt != nil { + if mnt := rp.vfs.getMountAt(ctx, rp.mount, d); mnt != nil { rp.nextMount = mnt return resolveMountPointError{} } @@ -389,11 +395,11 @@ func (rp *ResolvingPath) HandleJump(target VirtualDentry) error { return resolveMountRootOrJumpError{} } -func (rp *ResolvingPath) handleError(err error) bool { +func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { switch err.(type) { case resolveMountRootOrJumpError: // Switch to the new Mount. We hold references on the Mount and Dentry. - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = rp.nextMount rp.start = rp.nextStart rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef @@ -412,7 +418,7 @@ func (rp *ResolvingPath) handleError(err error) bool { case resolveMountPointError: // Switch to the new Mount. We hold a reference on the Mount, but // borrow the reference on the mount root from the Mount. - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = rp.nextMount rp.start = rp.nextMount.root rp.flags = rp.flags&^rpflagsHaveStartRef | rpflagsHaveMountRef @@ -423,12 +429,12 @@ func (rp *ResolvingPath) handleError(err error) bool { // path. rp.relpathCommit() // Restart path resolution on the new Mount. - rp.releaseErrorState() + rp.releaseErrorState(ctx) return true case resolveAbsSymlinkError: // Switch to the new Mount. References are borrowed from rp.root. - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = rp.root.mount rp.start = rp.root.dentry rp.flags &^= rpflagsHaveMountRef | rpflagsHaveStartRef @@ -440,7 +446,7 @@ func (rp *ResolvingPath) handleError(err error) bool { // path, including the symlink target we just prepended. rp.relpathCommit() // Restart path resolution on the new Mount. - rp.releaseErrorState() + rp.releaseErrorState(ctx) return true default: diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 58c7ad778..5bd756ea5 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -24,9 +24,9 @@ // Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry // VirtualFilesystem.filesystemsMu // EpollInstance.mu -// Inotify.mu -// Watches.mu -// Inotify.evMu +// Inotify.mu +// Watches.mu +// Inotify.evMu // VirtualFilesystem.fsTypesMu // // Locking Dentry.mu in multiple Dentries requires holding @@ -36,6 +36,7 @@ package vfs import ( "fmt" + "path" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -122,7 +123,10 @@ type VirtualFilesystem struct { } // Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes. -func (vfs *VirtualFilesystem) Init() error { +func (vfs *VirtualFilesystem) Init(ctx context.Context) error { + if vfs.mountpoints != nil { + panic("VFS already initialized") + } vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{}) vfs.devices = make(map[devTuple]*registeredDevice) vfs.anonBlockDevMinorNext = 1 @@ -142,7 +146,7 @@ func (vfs *VirtualFilesystem) Init() error { devMinor: anonfsDevMinor, } anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs) - defer anonfs.vfsfs.DecRef() + defer anonfs.vfsfs.DecRef(ctx) anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{}) if err != nil { // We should not be passing any MountOptions that would cause @@ -159,6 +163,8 @@ func (vfs *VirtualFilesystem) Init() error { // PathOperation is passed to VFS methods by pointer to reduce memory copying: // it's somewhat large and should never escape. (Options structs are passed by // pointer to VFS and FileDescription methods for the same reason.) +// +// +stateify savable type PathOperation struct { // Root is the VFS root. References on Root are borrowed from the provider // of the PathOperation. @@ -189,11 +195,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -211,11 +217,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede dentry: d, } rp.mount.IncRef() - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return vd, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return VirtualDentry{}, err } } @@ -233,7 +239,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } rp.mount.IncRef() name := rp.Component() - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return parentVD, name, nil } if checkInvariants { @@ -241,8 +247,8 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return VirtualDentry{}, "", err } } @@ -257,14 +263,14 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential } if !newpop.Path.Begin.Ok() { - oldVD.DecRef() + oldVD.DecRef(ctx) if newpop.Path.Absolute { return syserror.EEXIST } return syserror.ENOENT } if newpop.FollowFinalSymlink { - oldVD.DecRef() + oldVD.DecRef(ctx) ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink") return syserror.EINVAL } @@ -273,8 +279,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - vfs.putResolvingPath(rp) - oldVD.DecRef() + vfs.putResolvingPath(ctx, rp) + oldVD.DecRef(ctx) return nil } if checkInvariants { @@ -282,9 +288,9 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - oldVD.DecRef() + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) + oldVD.DecRef(ctx) return err } } @@ -293,6 +299,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential // MkdirAt creates a directory at the given path. func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with mkdirat(dirfd, "", mode). if pop.Path.Absolute { return syserror.EEXIST } @@ -310,7 +318,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -318,8 +326,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -329,6 +337,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia // error from the syserror package. func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with mknodat(dirfd, "", mode, dev). if pop.Path.Absolute { return syserror.EEXIST } @@ -343,7 +353,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -351,8 +361,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -405,31 +415,31 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) if opts.FileExec { if fd.Mount().Flags.NoExec { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EACCES } // Only a regular file can be executed. stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE}) if err != nil { - fd.DecRef() + fd.DecRef(ctx) return nil, err } if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EACCES } } - fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent) + fd.Dentry().InotifyWithParent(ctx, linux.IN_OPEN, 0, PathEvent) return fd, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return nil, err } } @@ -441,11 +451,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden for { target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return target, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return "", err } } @@ -469,19 +479,19 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti return err } if oldName == "." || oldName == ".." { - oldParentVD.DecRef() + oldParentVD.DecRef(ctx) return syserror.EBUSY } if !newpop.Path.Begin.Ok() { - oldParentVD.DecRef() + oldParentVD.DecRef(ctx) if newpop.Path.Absolute { return syserror.EBUSY } return syserror.ENOENT } if newpop.FollowFinalSymlink { - oldParentVD.DecRef() + oldParentVD.DecRef(ctx) ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink") return syserror.EINVAL } @@ -494,8 +504,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts) if err == nil { - vfs.putResolvingPath(rp) - oldParentVD.DecRef() + vfs.putResolvingPath(ctx, rp) + oldParentVD.DecRef(ctx) return nil } if checkInvariants { @@ -503,9 +513,9 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - oldParentVD.DecRef() + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) + oldParentVD.DecRef(ctx) return err } } @@ -514,6 +524,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti // RmdirAt removes the directory at the given path. func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with unlinkat(dirfd, "", AT_REMOVEDIR). if pop.Path.Absolute { return syserror.EBUSY } @@ -528,7 +540,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.RmdirAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -536,8 +548,8 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -549,11 +561,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -565,11 +577,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential for { stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return stat, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return linux.Statx{}, err } } @@ -582,11 +594,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti for { statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return statfs, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return linux.Statfs{}, err } } @@ -595,6 +607,8 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti // SymlinkAt creates a symbolic link at the given path with the given target. func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with symlinkat(oldpath, newdirfd, ""). if pop.Path.Absolute { return syserror.EEXIST } @@ -609,7 +623,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -617,8 +631,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -627,6 +641,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent // UnlinkAt deletes the non-directory file at the given path. func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with unlinkat(dirfd, "", 0). if pop.Path.Absolute { return syserror.EBUSY } @@ -641,7 +657,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.UnlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -649,8 +665,8 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -658,17 +674,11 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti // BoundEndpointAt gets the bound endpoint at the given path, if one exists. func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) { - if !pop.Path.Begin.Ok() { - if pop.Path.Absolute { - return nil, syserror.ECONNREFUSED - } - return nil, syserror.ENOENT - } rp := vfs.getResolvingPath(creds, pop) for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return bep, nil } if checkInvariants { @@ -676,21 +686,21 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return nil, err } } } -// ListxattrAt returns all extended attribute names for the file at the given +// ListXattrAt returns all extended attribute names for the file at the given // path. -func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { +func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { rp := vfs.getResolvingPath(creds, pop) for { - names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size) + names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return names, nil } if err == syserror.ENOTSUP { @@ -698,61 +708,61 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return nil, err } } } -// GetxattrAt returns the value associated with the given extended attribute +// GetXattrAt returns the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) { +func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetXattrOptions) (string, error) { rp := vfs.getResolvingPath(creds, pop) for { - val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts) + val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return val, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return "", err } } } -// SetxattrAt changes the value associated with the given extended attribute +// SetXattrAt changes the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { +func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetXattrOptions) error { rp := vfs.getResolvingPath(creds, pop) for { - err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } } -// RemovexattrAt removes the given extended attribute from the file at rp. -func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { +// RemoveXattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { rp := vfs.getResolvingPath(creds, pop) for { - err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -774,11 +784,67 @@ func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { if err := fs.impl.Sync(ctx); err != nil && retErr == nil { retErr = err } - fs.DecRef() + fs.DecRef(ctx) } return retErr } +// MkdirAllAt recursively creates non-existent directories on the given path +// (including the last component). +func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string, root VirtualDentry, creds *auth.Credentials, mkdirOpts *MkdirOptions) error { + pop := &PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(currentPath), + } + stat, err := vfs.StatAt(ctx, creds, pop, &StatOptions{Mask: linux.STATX_TYPE}) + switch err { + case nil: + if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.FileTypeMask != linux.ModeDirectory { + return syserror.ENOTDIR + } + // Directory already exists. + return nil + case syserror.ENOENT: + // Expected, we will create the dir. + default: + return fmt.Errorf("stat failed for %q during directory creation: %w", currentPath, err) + } + + // Recurse to ensure parent is created and then create the final directory. + if err := vfs.MkdirAllAt(ctx, path.Dir(currentPath), root, creds, mkdirOpts); err != nil { + return err + } + if err := vfs.MkdirAt(ctx, creds, pop, mkdirOpts); err != nil { + return fmt.Errorf("failed to create directory %q: %w", currentPath, err) + } + return nil +} + +// MakeSyntheticMountpoint creates parent directories of target if they do not +// exist and attempts to create a directory for the mountpoint. If a +// non-directory file already exists there then we allow it. +func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, target string, root VirtualDentry, creds *auth.Credentials) error { + mkdirOpts := &MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true} + + // Make sure the parent directory of target exists. + if err := vfs.MkdirAllAt(ctx, path.Dir(target), root, creds, mkdirOpts); err != nil { + return fmt.Errorf("failed to create parent directory of mountpoint %q: %w", target, err) + } + + // Attempt to mkdir the final component. If a file (of any type) exists + // then we let allow mounting on top of that because we do not require the + // target to be an existing directory, unlike Linux mount(2). + if err := vfs.MkdirAt(ctx, creds, &PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(target), + }, mkdirOpts); err != nil && err != syserror.EEXIST { + return fmt.Errorf("failed to create mountpoint %q: %w", target, err) + } + return nil +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). @@ -828,9 +894,9 @@ func (vd VirtualDentry) IncRef() { // DecRef decrements the reference counts on the Mount and Dentry represented // by vd. -func (vd VirtualDentry) DecRef() { - vd.dentry.DecRef() - vd.mount.DecRef() +func (vd VirtualDentry) DecRef(ctx context.Context) { + vd.dentry.DecRef(ctx) + vd.mount.DecRef(ctx) } // Mount returns the Mount associated with vd. It does not take a reference on diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 748273366..bbafb8b7f 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -96,15 +96,33 @@ const ( Panic ) +// Set implements flag.Value. +func (a *Action) Set(v string) error { + switch v { + case "log", "logwarning": + *a = LogWarning + case "panic": + *a = Panic + default: + return fmt.Errorf("invalid watchdog action %q", v) + } + return nil +} + +// Get implements flag.Value. +func (a *Action) Get() interface{} { + return *a +} + // String returns Action's string representation. -func (a Action) String() string { - switch a { +func (a *Action) String() string { + switch *a { case LogWarning: - return "LogWarning" + return "logWarning" case Panic: - return "Panic" + return "panic" default: - panic(fmt.Sprintf("Invalid action: %d", a)) + panic(fmt.Sprintf("Invalid watchdog action: %d", *a)) } } diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD new file mode 100644 index 000000000..f08599ebd --- /dev/null +++ b/pkg/shim/runsc/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "runsc", + srcs = [ + "runsc.go", + "utils.go", + ], + visibility = ["//:sandbox"], + deps = [ + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go new file mode 100644 index 000000000..c5cf68efa --- /dev/null +++ b/pkg/shim/runsc/runsc.go @@ -0,0 +1,514 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "syscall" + "time" + + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +var Monitor runc.ProcessMonitor = runc.Monitor + +// DefaultCommand is the default command for Runsc. +const DefaultCommand = "runsc" + +// Runsc is the client to the runsc cli. +type Runsc struct { + Command string + PdeathSignal syscall.Signal + Setpgid bool + Root string + Log string + LogFormat runc.Format + Config map[string]string +} + +// List returns all containers created inside the provided runsc root directory. +func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { + data, err := cmdOutput(r.command(context, "list", "--format=json"), false) + if err != nil { + return nil, err + } + var out []*runc.Container + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return out, nil +} + +// State returns the state for the container provided by id. +func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) { + data, err := cmdOutput(r.command(context, "state", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var c runc.Container + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil +} + +type CreateOpts struct { + runc.IO + ConsoleSocket runc.ConsoleSocket + + // PidFile is a path to where a pid file should be created. + PidFile string + + // UserLog is a path to where runsc user log should be generated. + UserLog string +} + +func (o *CreateOpts) args() (out []string, err error) { + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.UserLog != "" { + out = append(out, "--user-log", o.UserLog) + } + return out, nil +} + +// Create creates a new container and returns its pid if it was created successfully. +func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error { + args := []string{"create", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return err +} + +// Start will start an already created container. +func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error { + cmd := r.command(context, "start", id) + if cio != nil { + cio.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if cio != nil { + if c, ok := cio.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return err +} + +type waitResult struct { + ID string `json:"id"` + ExitStatus int `json:"exitStatus"` +} + +// Wait will wait for a running container, and return its exit status. +// +// TODO(random-liu): Add exec process support. +func (r *Runsc) Wait(context context.Context, id string) (int, error) { + data, err := cmdOutput(r.command(context, "wait", id), true) + if err != nil { + return 0, fmt.Errorf("%s: %s", err, data) + } + var res waitResult + if err := json.Unmarshal(data, &res); err != nil { + return 0, err + } + return res.ExitStatus, nil +} + +type ExecOpts struct { + runc.IO + PidFile string + InternalPidFile string + ConsoleSocket runc.ConsoleSocket + Detach bool +} + +func (o *ExecOpts) args() (out []string, err error) { + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.Detach { + out = append(out, "--detach") + } + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.InternalPidFile != "" { + abs, err := filepath.Abs(o.InternalPidFile) + if err != nil { + return nil, err + } + out = append(out, "--internal-pid-file", abs) + } + return out, nil +} + +// Exec executes an additional process inside the container based on a full OCI +// Process specification. +func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error { + f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process") + if err != nil { + return err + } + defer os.Remove(f.Name()) + err = json.NewEncoder(f).Encode(spec) + f.Close() + if err != nil { + return err + } + args := []string{"exec", "--process", f.Name()} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + return err +} + +// Run runs the create, start, delete lifecycle of the container and returns +// its exit status after it has exited. +func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) { + args := []string{"run", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return -1, err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + ec, err := Monitor.Start(cmd) + if err != nil { + return -1, err + } + return Monitor.Wait(cmd, ec) +} + +type DeleteOpts struct { + Force bool +} + +func (o *DeleteOpts) args() (out []string) { + if o.Force { + out = append(out, "--force") + } + return out +} + +// Delete deletes the container. +func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error { + args := []string{"delete"} + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id)...)) +} + +// KillOpts specifies options for killing a container and its processes. +type KillOpts struct { + All bool + Pid int +} + +func (o *KillOpts) args() (out []string) { + if o.All { + out = append(out, "--all") + } + if o.Pid != 0 { + out = append(out, "--pid", strconv.Itoa(o.Pid)) + } + return out +} + +// Kill sends the specified signal to the container. +func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error { + args := []string{ + "kill", + } + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...)) +} + +// Stats return the stats for a container like cpu, memory, and I/O. +func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { + cmd := r.command(context, "events", "--stats", id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + defer func() { + rd.Close() + Monitor.Wait(cmd, ec) + }() + var e runc.Event + if err := json.NewDecoder(rd).Decode(&e); err != nil { + return nil, err + } + return e.Stats, nil +} + +// Events returns an event stream from runsc for a container with stats and OOM notifications. +func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) { + cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + rd.Close() + return nil, err + } + var ( + dec = json.NewDecoder(rd) + c = make(chan *runc.Event, 128) + ) + go func() { + defer func() { + close(c) + rd.Close() + Monitor.Wait(cmd, ec) + }() + for { + var e runc.Event + if err := dec.Decode(&e); err != nil { + if err == io.EOF { + return + } + e = runc.Event{ + Type: "error", + Err: err, + } + } + c <- &e + } + }() + return c, nil +} + +// Ps lists all the processes inside the container returning their pids. +func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var pids []int + if err := json.Unmarshal(data, &pids); err != nil { + return nil, err + } + return pids, nil +} + +// Top lists all the processes inside the container returning the full ps data. +func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + + topResults, err := runc.ParsePSOutput(data) + if err != nil { + return nil, fmt.Errorf("%s: ", err) + } + return topResults, nil +} + +func (r *Runsc) args() []string { + var args []string + if r.Root != "" { + args = append(args, fmt.Sprintf("--root=%s", r.Root)) + } + if r.Log != "" { + args = append(args, fmt.Sprintf("--log=%s", r.Log)) + } + if r.LogFormat != "" { + args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat)) + } + for k, v := range r.Config { + args = append(args, fmt.Sprintf("--%s=%s", k, v)) + } + return args +} + +// runOrError will run the provided command. +// +// If an error is encountered and neither Stdout or Stderr was set the error +// will be returned in the format of <error>: <stderr>. +func (r *Runsc) runOrError(cmd *exec.Cmd) error { + if cmd.Stdout != nil || cmd.Stderr != nil { + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + return err + } + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil +} + +func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd { + command := r.Command + if command == "" { + command = DefaultCommand + } + cmd := exec.CommandContext(context, command, append(r.args(), args...)...) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: r.Setpgid, + } + if r.PdeathSignal != 0 { + cmd.SysProcAttr.Pdeathsig = r.PdeathSignal + } + + return cmd +} + +func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) { + b := getBuf() + defer putBuf(b) + + cmd.Stdout = b + if combined { + cmd.Stderr = b + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return b.Bytes(), err +} diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go new file mode 100644 index 000000000..c514b3bc7 --- /dev/null +++ b/pkg/shim/runsc/utils.go @@ -0,0 +1,44 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "bytes" + "strings" + "sync" +) + +var bytesBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, +} + +func getBuf() *bytes.Buffer { + return bytesBufferPool.Get().(*bytes.Buffer) +} + +func putBuf(b *bytes.Buffer) { + b.Reset() + bytesBufferPool.Put(b) +} + +// FormatLogPath parses runsc config, and fill in %ID% in the log path. +func FormatLogPath(id string, config map[string]string) { + if path, ok := config["debug-log"]; ok { + config["debug-log"] = strings.Replace(path, "%ID%", id, -1) + } +} diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD new file mode 100644 index 000000000..4377306af --- /dev/null +++ b/pkg/shim/v1/proc/BUILD @@ -0,0 +1,36 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "proc", + srcs = [ + "deleted_state.go", + "exec.go", + "exec_state.go", + "init.go", + "init_state.go", + "io.go", + "process.go", + "types.go", + "utils.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go new file mode 100644 index 000000000..d9b970c4d --- /dev/null +++ b/pkg/shim/v1/proc/deleted_state.go @@ -0,0 +1,49 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type deletedState struct{} + +func (*deletedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a deleted process.ss") +} + +func (*deletedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a deleted process.ss") +} + +func (*deletedState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound) +} + +func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error { + return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound) +} + +func (*deletedState) SetExited(status int) {} + +func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a deleted state") +} diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go new file mode 100644 index 000000000..1d1d90488 --- /dev/null +++ b/pkg/shim/v1/proc/exec.go @@ -0,0 +1,281 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +type execProcess struct { + wg sync.WaitGroup + + execState execState + + mu sync.Mutex + id string + console console.Console + io runc.IO + status int + exited time.Time + pid int + internalPid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + path string + spec specs.Process + + parent *Init + waitBlock chan struct{} +} + +func (e *execProcess) Wait() { + <-e.waitBlock +} + +func (e *execProcess) ID() string { + return e.id +} + +func (e *execProcess) Pid() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.pid +} + +func (e *execProcess) ExitStatus() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.status +} + +func (e *execProcess) ExitedAt() time.Time { + e.mu.Lock() + defer e.mu.Unlock() + return e.exited +} + +func (e *execProcess) SetExited(status int) { + e.mu.Lock() + defer e.mu.Unlock() + + e.execState.SetExited(status) +} + +func (e *execProcess) setExited(status int) { + e.status = status + e.exited = time.Now() + e.parent.Platform.ShutdownConsole(context.Background(), e.console) + close(e.waitBlock) +} + +func (e *execProcess) Delete(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Delete(ctx) +} + +func (e *execProcess) delete(ctx context.Context) error { + e.wg.Wait() + if e.io != nil { + for _, c := range e.closers { + c.Close() + } + e.io.Close() + } + pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + // silently ignore error + os.Remove(pidfile) + internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + // silently ignore error + os.Remove(internalPidfile) + return nil +} + +func (e *execProcess) Resize(ws console.WinSize) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Resize(ws) +} + +func (e *execProcess) resize(ws console.WinSize) error { + if e.console == nil { + return nil + } + return e.console.Resize(ws) +} + +func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Kill(ctx, sig, false) +} + +func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error { + internalPid := e.internalPid + if internalPid != 0 { + if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{ + Pid: internalPid, + }); err != nil { + // If this returns error, consider the process has + // already stopped. + // + // TODO: Fix after signal handling is fixed. + return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) + } + } + return nil +} + +func (e *execProcess) Stdin() io.Closer { + return e.stdin +} + +func (e *execProcess) Stdio() stdio.Stdio { + return e.stdio +} + +func (e *execProcess) Start(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Start(ctx) +} + +func (e *execProcess) start(ctx context.Context) (err error) { + var ( + socket *runc.Socket + pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + ) + if e.stdio.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create runc console socket: %w", err) + } + defer socket.Close() + } else if e.stdio.IsNull() { + if e.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil { + return fmt.Errorf("failed to create runc io pipes: %w", err) + } + } + opts := &runsc.ExecOpts{ + PidFile: pidfile, + InternalPidFile: internalPidfile, + IO: e.io, + Detach: true, + } + if socket != nil { + opts.ConsoleSocket = socket + } + eventCh := e.parent.Monitor.Subscribe() + defer func() { + // Unsubscribe if an error is returned. + if err != nil { + e.parent.Monitor.Unsubscribe(eventCh) + } + }() + if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil { + close(e.waitBlock) + return e.parent.runtimeError(err, "OCI runtime exec failed") + } + if e.stdio.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err) + } + e.closers = append(e.closers, sc) + e.stdin = sc + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + } else if !e.stdio.IsNull() { + if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(opts.PidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err) + } + e.pid = pid + internalPid, err := runc.ReadPidFile(opts.InternalPidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err) + } + e.internalPid = internalPid + go func() { + defer e.parent.Monitor.Unsubscribe(eventCh) + for event := range eventCh { + if event.Pid == e.pid { + ExitCh <- Exit{ + Timestamp: event.Timestamp, + ID: e.id, + Status: event.Status, + } + break + } + } + }() + return nil +} + +func (e *execProcess) Status(ctx context.Context) (string, error) { + e.mu.Lock() + defer e.mu.Unlock() + // if we don't have a pid then the exec process has just been created + if e.pid == 0 { + return "created", nil + } + // if we have a pid and it can be signaled, the process is running + // TODO(random-liu): Use `runsc kill --pid`. + if err := unix.Kill(e.pid, 0); err == nil { + return "running", nil + } + // else if we have a pid but it can nolonger be signaled, it has stopped + return "stopped", nil +} diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go new file mode 100644 index 000000000..4dcda8b44 --- /dev/null +++ b/pkg/shim/v1/proc/exec_state.go @@ -0,0 +1,154 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" +) + +type execState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type execCreatedState struct { + p *execProcess +} + +func (s *execCreatedState) transition(name string) error { + switch name { + case "running": + s.p.execState = &execRunningState{p: s.p} + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execCreatedState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execCreatedState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + return err + } + return s.transition("running") +} + +func (s *execCreatedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execCreatedState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execRunningState struct { + p *execProcess +} + +func (s *execRunningState) transition(name string) error { + switch name { + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execRunningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execRunningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process") +} + +func (s *execRunningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process") +} + +func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execRunningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execStoppedState struct { + p *execProcess +} + +func (s *execStoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execStoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *execStoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process") +} + +func (s *execStoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execStoppedState) SetExited(status int) { + // no op +} diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go new file mode 100644 index 000000000..dab3123d6 --- /dev/null +++ b/pkg/shim/v1/proc/init.go @@ -0,0 +1,460 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +// InitPidFile name of the file that contains the init pid. +const InitPidFile = "init.pid" + +// Init represents an initial process for a container. +type Init struct { + wg sync.WaitGroup + initState initState + + // mu is used to ensure that `Start()` and `Exited()` calls return in + // the right order when invoked in separate go routines. This is the + // case within the shim implementation as it makes use of the reaper + // interface. + mu sync.Mutex + + waitBlock chan struct{} + + WorkDir string + + id string + Bundle string + console console.Console + Platform stdio.Platform + io runc.IO + runtime *runsc.Runsc + status int + exited time.Time + pid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + Rootfs string + IoUID int + IoGID int + Sandbox bool + UserLog string + Monitor ProcessMonitor +} + +// NewRunsc returns a new runsc instance for a process. +func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc { + if root == "" { + root = RunscRoot + } + return &runsc.Runsc{ + Command: runtime, + PdeathSignal: syscall.SIGKILL, + Log: filepath.Join(path, "log.json"), + LogFormat: runc.JSON, + Root: filepath.Join(root, namespace), + Config: config, + } +} + +// New returns a new init process. +func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init { + p := &Init{ + id: id, + runtime: runtime, + stdio: stdio, + status: 0, + waitBlock: make(chan struct{}), + } + p.initState = &createdState{p: p} + return p +} + +// Create the process with the provided config. +func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) { + var socket *runc.Socket + if r.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create OCI runtime console socket: %w", err) + } + defer socket.Close() + } else if hasNoIO(r) { + if p.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil { + return fmt.Errorf("failed to create OCI runtime io pipes: %w", err) + } + } + pidFile := filepath.Join(p.Bundle, InitPidFile) + opts := &runsc.CreateOpts{ + PidFile: pidFile, + } + if socket != nil { + opts.ConsoleSocket = socket + } + if p.Sandbox { + opts.IO = p.io + // UserLog is only useful for sandbox. + opts.UserLog = p.UserLog + } + if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil { + return p.runtimeError(err, "OCI runtime create failed") + } + if r.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err) + } + p.stdin = sc + p.closers = append(p.closers, sc) + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg) + if err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + p.console = console + } else if !hasNoIO(r) { + if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(pidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err) + } + p.pid = pid + return nil +} + +// Wait waits for the process to exit. +func (p *Init) Wait() { + <-p.waitBlock +} + +// ID returns the ID of the process. +func (p *Init) ID() string { + return p.id +} + +// Pid returns the PID of the process. +func (p *Init) Pid() int { + return p.pid +} + +// ExitStatus returns the exit status of the process. +func (p *Init) ExitStatus() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.status +} + +// ExitedAt returns the time when the process exited. +func (p *Init) ExitedAt() time.Time { + p.mu.Lock() + defer p.mu.Unlock() + return p.exited +} + +// Status returns the status of the process. +func (p *Init) Status(ctx context.Context) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + c, err := p.runtime.State(ctx, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return "stopped", nil + } + return "", p.runtimeError(err, "OCI runtime state failed") + } + return p.convertStatus(c.Status), nil +} + +// Start starts the init process. +func (p *Init) Start(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Start(ctx) +} + +func (p *Init) start(ctx context.Context) error { + var cio runc.IO + if !p.Sandbox { + cio = p.io + } + if err := p.runtime.Start(ctx, p.id, cio); err != nil { + return p.runtimeError(err, "OCI runtime start failed") + } + go func() { + status, err := p.runtime.Wait(context.Background(), p.id) + if err != nil { + log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id) + // TODO(random-liu): Handle runsc kill error. + if err := p.killAll(ctx); err != nil { + log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id) + } + status = internalErrorCode + } + ExitCh <- Exit{ + Timestamp: time.Now(), + ID: p.id, + Status: status, + } + }() + return nil +} + +// SetExited set the exit stauts of the init process. +func (p *Init) SetExited(status int) { + p.mu.Lock() + defer p.mu.Unlock() + + p.initState.SetExited(status) +} + +func (p *Init) setExited(status int) { + p.exited = time.Now() + p.status = status + p.Platform.ShutdownConsole(context.Background(), p.console) + close(p.waitBlock) +} + +// Delete deletes the init process. +func (p *Init) Delete(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Delete(ctx) +} + +func (p *Init) delete(ctx context.Context) error { + p.killAll(ctx) + p.wg.Wait() + err := p.runtime.Delete(ctx, p.id, nil) + // ignore errors if a runtime has already deleted the process + // but we still hold metadata and pipes + // + // this is common during a checkpoint, runc will delete the container state + // after a checkpoint and the container will no longer exist within runc + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + err = nil + } else { + err = p.runtimeError(err, "failed to delete task") + } + } + if p.io != nil { + for _, c := range p.closers { + c.Close() + } + p.io.Close() + } + if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount") + if err == nil { + err = fmt.Errorf("failed rootfs umount: %w", err2) + } + } + return err +} + +// Resize resizes the init processes console. +func (p *Init) Resize(ws console.WinSize) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +func (p *Init) resize(ws console.WinSize) error { + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +// Kill kills the init process. +func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Kill(ctx, signal, all) +} + +func (p *Init) kill(context context.Context, signal uint32, all bool) error { + var ( + killErr error + backoff = 100 * time.Millisecond + ) + timeout := 1 * time.Second + for start := time.Now(); time.Now().Sub(start) < timeout; { + c, err := p.runtime.State(context, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + return p.runtimeError(err, "OCI runtime state failed") + } + // For runsc, signal only works when container is running state. + // If the container is not in running state, directly return + // "no such process" + if p.convertStatus(c.Status) == "stopped" { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{ + All: all, + }) + if killErr == nil { + return nil + } + time.Sleep(backoff) + backoff *= 2 + } + return p.runtimeError(killErr, "kill timeout") +} + +// KillAll kills all processes belonging to the init process. +func (p *Init) KillAll(context context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + return p.killAll(context) +} + +func (p *Init) killAll(context context.Context) error { + p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{ + All: true, + }) + // Ignore error handling for `runsc kill --all` for now. + // * If it doesn't return error, it is good; + // * If it returns error, consider the container has already stopped. + // TODO: Fix `runsc kill --all` error handling. + return nil +} + +// Stdin returns the stdin of the process. +func (p *Init) Stdin() io.Closer { + return p.stdin +} + +// Runtime returns the OCI runtime configured for the init process. +func (p *Init) Runtime() *runsc.Runsc { + return p.runtime +} + +// Exec returns a new child process. +func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Exec(ctx, path, r) +} + +// exec returns a new exec'd process. +func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + // process exec request + var spec specs.Process + if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { + return nil, err + } + spec.Terminal = r.Terminal + + e := &execProcess{ + id: r.ID, + path: path, + parent: p, + spec: spec, + stdio: stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }, + waitBlock: make(chan struct{}), + } + e.execState = &execCreatedState{p: e} + return e, nil +} + +// Stdio returns the stdio of the process. +func (p *Init) Stdio() stdio.Stdio { + return p.stdio +} + +func (p *Init) runtimeError(rErr error, msg string) error { + if rErr == nil { + return nil + } + + rMsg, err := getLastRuntimeError(p.runtime) + switch { + case err != nil: + return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err) + case rMsg == "": + return fmt.Errorf("%s: %w", msg, rErr) + default: + return fmt.Errorf("%s: %s", msg, rMsg) + } +} + +func (p *Init) convertStatus(status string) string { + if status == "created" && !p.Sandbox && p.status == internalErrorCode { + // Treat start failure state for non-root container as stopped. + return "stopped" + } + return status +} + +func withConditionalIO(c stdio.Stdio) runc.IOOpt { + return func(o *runc.IOOption) { + o.OpenStdin = c.Stdin != "" + o.OpenStdout = c.Stdout != "" + o.OpenStderr = c.Stderr != "" + } +} diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go new file mode 100644 index 000000000..9233ecc85 --- /dev/null +++ b/pkg/shim/v1/proc/init_state.go @@ -0,0 +1,182 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type initState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Exec(context.Context, string, *ExecConfig) (process.Process, error) + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type createdState struct { + p *Init +} + +func (s *createdState) transition(name string) error { + switch name { + case "running": + s.p.initState = &runningState{p: s.p} + case "stopped": + s.p.initState = &stoppedState{p: s.p} + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *createdState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *createdState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + // Containerd doesn't allow deleting container in created state. + // However, for gvisor, a non-root container in created state can + // only go to running state. If the container can't be started, + // it can only stay in created state, and never be deleted. + // To work around that, we treat non-root container in start failure + // state as stopped. + if !s.p.Sandbox { + s.p.io.Close() + s.p.setExited(internalErrorCode) + if err := s.transition("stopped"); err != nil { + panic(err) + } + } + return err + } + return s.transition("running") +} + +func (s *createdState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *createdState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type runningState struct { + p *Init +} + +func (s *runningState) transition(name string) error { + switch name { + case "stopped": + s.p.initState = &stoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *runningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *runningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process.ss") +} + +func (s *runningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process.ss") +} + +func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *runningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type stoppedState struct { + p *Init +} + +func (s *stoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *stoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *stoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process.ss") +} + +func (s *stoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id) +} + +func (s *stoppedState) SetExited(status int) { + // no op +} + +func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a stopped state") +} diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go new file mode 100644 index 000000000..34d825fb7 --- /dev/null +++ b/pkg/shim/v1/proc/io.go @@ -0,0 +1,162 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "syscall" + + "github.com/containerd/containerd/log" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" +) + +// TODO(random-liu): This file can be a util. + +var bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, +} + +func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error { + var sameFile *countingWriteCloser + for _, i := range []struct { + name string + dest func(wc io.WriteCloser, rc io.Closer) + }{ + { + name: stdout, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil { + log.G(ctx).Warn("error copying stdout") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, { + name: stderr, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil { + log.G(ctx).Warn("error copying stderr") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, + } { + ok, err := isFifo(i.name) + if err != nil { + return err + } + var ( + fw io.WriteCloser + fr io.Closer + ) + if ok { + if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + } else { + if sameFile != nil { + sameFile.count++ + i.dest(sameFile, nil) + continue + } + if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if stdout == stderr { + sameFile = &countingWriteCloser{ + WriteCloser: fw, + count: 1, + } + } + } + i.dest(fw, fr) + } + if stdin == "" { + return nil + } + f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err) + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + + io.CopyBuffer(rio.Stdin(), f, *p) + rio.Stdin().Close() + f.Close() + }() + return nil +} + +// countingWriteCloser masks io.Closer() until close has been invoked a certain number of times. +type countingWriteCloser struct { + io.WriteCloser + count int64 +} + +func (c *countingWriteCloser) Close() error { + if atomic.AddInt64(&c.count, -1) > 0 { + return nil + } + return c.WriteCloser.Close() +} + +// isFifo checks if a file is a fifo. +// +// If the file does not exist then it returns false. +func isFifo(path string) (bool, error) { + stat, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe { + return true, nil + } + return false, nil +} diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go new file mode 100644 index 000000000..d462c3eef --- /dev/null +++ b/pkg/shim/v1/proc/process.go @@ -0,0 +1,37 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "fmt" +) + +// RunscRoot is the path to the root runsc state directory. +const RunscRoot = "/run/containerd/runsc" + +func stateName(v interface{}) string { + switch v.(type) { + case *runningState, *execRunningState: + return "running" + case *createdState, *execCreatedState: + return "created" + case *deletedState: + return "deleted" + case *stoppedState: + return "stopped" + } + panic(fmt.Errorf("invalid state %v", v)) +} diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go new file mode 100644 index 000000000..2b0df4663 --- /dev/null +++ b/pkg/shim/v1/proc/types.go @@ -0,0 +1,69 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "time" + + runc "github.com/containerd/go-runc" + "github.com/gogo/protobuf/types" +) + +// Mount holds filesystem mount configuration. +type Mount struct { + Type string + Source string + Target string + Options []string +} + +// CreateConfig hold task creation configuration. +type CreateConfig struct { + ID string + Bundle string + Runtime string + Rootfs []Mount + Terminal bool + Stdin string + Stdout string + Stderr string + Options *types.Any +} + +// ExecConfig holds exec creation configuration. +type ExecConfig struct { + ID string + Terminal bool + Stdin string + Stdout string + Stderr string + Spec *types.Any +} + +// Exit is the type of exit events. +type Exit struct { + Timestamp time.Time + ID string + Status int +} + +// ProcessMonitor monitors process exit changes. +type ProcessMonitor interface { + // Subscribe to process exit changes + Subscribe() chan runc.Exit + // Unsubscribe to process exit changes + Unsubscribe(c chan runc.Exit) +} diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go new file mode 100644 index 000000000..716de2f59 --- /dev/null +++ b/pkg/shim/v1/proc/utils.go @@ -0,0 +1,90 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "encoding/json" + "io" + "os" + "strings" + "time" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +const ( + internalErrorCode = 128 + bufferSize = 32 +) + +// ExitCh is the exit events channel for containers and exec processes +// inside the sandbox. +var ExitCh = make(chan Exit, bufferSize) + +// TODO(mlaventure): move to runc package? +func getLastRuntimeError(r *runsc.Runsc) (string, error) { + if r.Log == "" { + return "", nil + } + + f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400) + if err != nil { + return "", err + } + + var ( + errMsg string + log struct { + Level string + Msg string + Time time.Time + } + ) + + dec := json.NewDecoder(f) + for err = nil; err == nil; { + if err = dec.Decode(&log); err != nil && err != io.EOF { + return "", err + } + if log.Level == "error" { + errMsg = strings.TrimSpace(log.Msg) + } + } + + return errMsg, nil +} + +func copyFile(to, from string) error { + ff, err := os.Open(from) + if err != nil { + return err + } + defer ff.Close() + tt, err := os.Create(to) + if err != nil { + return err + } + defer tt.Close() + + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + _, err = io.CopyBuffer(tt, ff, *p) + return err +} + +func hasNoIO(r *CreateConfig) bool { + return r.Stdin == "" && r.Stdout == "" && r.Stderr == "" +} diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD new file mode 100644 index 000000000..05c595bc9 --- /dev/null +++ b/pkg/shim/v1/shim/BUILD @@ -0,0 +1,40 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "shim", + srcs = [ + "api.go", + "platform.go", + "service.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", + ], +) diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go new file mode 100644 index 000000000..5dd8ff172 --- /dev/null +++ b/pkg/shim/v1/shim/api.go @@ -0,0 +1,28 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskCreate = events.TaskCreate +type TaskStart = events.TaskStart +type TaskOOM = events.TaskOOM +type TaskExit = events.TaskExit +type TaskDelete = events.TaskDelete +type TaskExecAdded = events.TaskExecAdded +type TaskExecStarted = events.TaskExecStarted diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go new file mode 100644 index 000000000..f590f80ef --- /dev/null +++ b/pkg/shim/v1/shim/platform.go @@ -0,0 +1,106 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *Service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go new file mode 100644 index 000000000..84a810cb2 --- /dev/null +++ b/pkg/shim/v1/shim/service.go @@ -0,0 +1,573 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/containerd/console" + "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/containerd/runtime" + "github.com/containerd/containerd/runtime/linux/runctypes" + shim "github.com/containerd/containerd/runtime/v1/shim/v1" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/proc" + "gvisor.dev/gvisor/pkg/shim/v1/utils" +) + +var ( + empty = &types.Empty{} + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, + } +) + +// Config contains shim specific configuration. +type Config struct { + Path string + Namespace string + WorkDir string + RuntimeRoot string + RunscConfig map[string]string +} + +// NewService returns a new shim service that can be used via GRPC. +func NewService(config Config, publisher events.Publisher) (*Service, error) { + if config.Namespace == "" { + return nil, fmt.Errorf("shim namespace cannot be empty") + } + ctx := namespaces.WithNamespace(context.Background(), config.Namespace) + s := &Service{ + config: config, + context: ctx, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + } + go s.processExits() + if err := s.initPlatform(); err != nil { + return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) + } + go s.forward(publisher) + return s, nil +} + +// Service is the shim implementation of a remote shim over GRPC. +type Service struct { + mu sync.Mutex + + config Config + context context.Context + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + ec chan proc.Exit + + // Filled by Create() + id string + bundle string +} + +// Create creates a new initial process and container with the underlying OCI runtime. +func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: r.Runtime, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + defer func() { + if err != nil { + if err2 := mount.UnmountAll(rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount") + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + s.config.Path, + s.config.WorkDir, + s.config.RuntimeRoot, + s.config.Namespace, + s.config.RunscConfig, + s.platform, + config, + ) + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + pid := process.Pid() + s.processes[r.ID] = process + return &shim.CreateTaskResponse{ + Pid: uint32(pid), + }, nil +} + +// Start starts a process. +func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + return &shim.StartResponse{ + ID: p.ID(), + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + s.mu.Lock() + delete(s.processes, s.id) + s.mu.Unlock() + s.platform.Close() + return &shim.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// DeleteProcess deletes an exec'd process. +func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) { + if r.ID == s.id { + return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess") + } + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + s.mu.Lock() + delete(s.processes, r.ID) + s.mu.Unlock() + return &shim.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + + if p := s.processes[r.ID]; p != nil { + s.mu.Unlock() + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID) + } + + p := s.processes[s.id] + s.mu.Unlock() + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + + process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{ + ID: r.ID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resises the terminal of a process. +func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) { + if r.ID == "" { + return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided") + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &shim.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause pauses the container. +func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume resumes the container. +func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill kills a process with the provided signal. +func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) { + if r.ID == "" { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil + } + + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// ListPids returns all pids inside the container. +func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &shim.ListPidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// ShimInfo returns shim information such as the shim's pid. +func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) { + return &shim.ShimInfoResponse{ + ShimPid: uint32(os.Getpid()), + }, nil +} + +// Update updates a running container. +func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + p.Wait() + + return &shim.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *Service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *Service) allProcesses() []process.Process { + s.mu.Lock() + defer s.mu.Unlock() + + res := make([]process.Process, 0, len(s.processes)) + for _, p := range s.processes { + res = append(res, p) + } + return res +} + +func (s *Service) checkProcesses(e proc.Exit) { + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *Service) forward(publisher events.Publisher) { + for e := range s.events { + if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil { + log.G(s.context).WithError(err).Error("post event") + } + } +} + +// getInitProcess returns the init process. +func (s *Service) getInitProcess() (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[s.id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + return p, nil +} + +// getExecProcess returns the given exec process. +func (s *Service) getExecProcess(id string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id) + } + return p, nil +} + +func getTopic(ctx context.Context, e interface{}) string { + switch e.(type) { + case *TaskCreate: + return runtime.TaskCreateEventTopic + case *TaskStart: + return runtime.TaskStartEventTopic + case *TaskOOM: + return runtime.TaskOOMEventTopic + case *TaskExit: + return runtime.TaskExitEventTopic + case *TaskDelete: + return runtime.TaskDeleteEventTopic + case *TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) { + var options runctypes.CreateOptions + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + options = *v.(*runctypes.CreateOptions) + } + + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + + runsc.FormatLogPath(r.ID, config) + rootfs := filepath.Join(path, "rootfs") + runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = utils.IsSandbox(spec) + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD new file mode 100644 index 000000000..54a0aabb7 --- /dev/null +++ b/pkg/shim/v1/utils/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "utils", + srcs = [ + "annotations.go", + "utils.go", + "volumes.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) + +go_test( + name = "utils_test", + size = "small", + srcs = ["volumes_test.go"], + library = ":utils", + deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], +) diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go new file mode 100644 index 000000000..1e9d3f365 --- /dev/null +++ b/pkg/shim/v1/utils/annotations.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +// Annotations from the CRI annotations package. +// +// These are vendor due to import conflicts. +const ( + sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory" + containerTypeAnnotation = "io.kubernetes.cri.container-type" + containerTypeSandbox = "sandbox" + containerTypeContainer = "container" +) diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go new file mode 100644 index 000000000..07e346654 --- /dev/null +++ b/pkg/shim/v1/utils/utils.go @@ -0,0 +1,56 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "io/ioutil" + "os" + "path/filepath" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +// ReadSpec reads OCI spec from the bundle directory. +func ReadSpec(bundle string) (*specs.Spec, error) { + f, err := os.Open(filepath.Join(bundle, "config.json")) + if err != nil { + return nil, err + } + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + return nil, err + } + return &spec, nil +} + +// IsSandbox checks whether a container is a sandbox container. +func IsSandbox(spec *specs.Spec) bool { + t, ok := spec.Annotations[containerTypeAnnotation] + return !ok || t == containerTypeSandbox +} + +// UserLogPath gets user log path from OCI annotation. +func UserLogPath(spec *specs.Spec) string { + sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation] + if sandboxLogDir == "" { + return "" + } + return filepath.Join(sandboxLogDir, "gvisor.log") +} diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go new file mode 100644 index 000000000..52a428179 --- /dev/null +++ b/pkg/shim/v1/utils/volumes.go @@ -0,0 +1,155 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +const volumeKeyPrefix = "dev.gvisor.spec.mount." + +var kubeletPodsDir = "/var/lib/kubelet/pods" + +// volumeName gets volume name from volume annotation key, example: +// dev.gvisor.spec.mount.NAME.share +func volumeName(k string) string { + return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0] +} + +// volumeFieldName gets volume field name from volume annotation key, example: +// `type` is the field of dev.gvisor.spec.mount.NAME.type +func volumeFieldName(k string) string { + parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".") + return parts[len(parts)-1] +} + +// podUID gets pod UID from the pod log path. +func podUID(s *specs.Spec) (string, error) { + sandboxLogDir := s.Annotations[sandboxLogDirAnnotation] + if sandboxLogDir == "" { + return "", fmt.Errorf("no sandbox log path annotation") + } + fields := strings.Split(filepath.Base(sandboxLogDir), "_") + switch len(fields) { + case 1: // This is the old CRI logging path. + return fields[0], nil + case 3: // This is the new CRI logging path. + return fields[2], nil + } + return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir) +} + +// isVolumeKey checks whether an annotation key is for volume. +func isVolumeKey(k string) bool { + return strings.HasPrefix(k, volumeKeyPrefix) +} + +// volumeSourceKey constructs the annotation key for volume source. +func volumeSourceKey(volume string) string { + return volumeKeyPrefix + volume + ".source" +} + +// volumePath searches the volume path in the kubelet pod directory. +func volumePath(volume, uid string) (string, error) { + // TODO: Support subpath when gvisor supports pod volume bind mount. + volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume) + dirs, err := filepath.Glob(volumeSearchPath) + if err != nil { + return "", err + } + if len(dirs) != 1 { + return "", fmt.Errorf("unexpected matched volume list %v", dirs) + } + return dirs[0], nil +} + +// isVolumePath checks whether a string is the volume path. +func isVolumePath(volume, path string) (bool, error) { + // TODO: Support subpath when gvisor supports pod volume bind mount. + volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume) + return filepath.Match(volumeSearchPath, path) +} + +// UpdateVolumeAnnotations add necessary OCI annotations for gvisor +// volume optimization. +func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { + var ( + uid string + err error + ) + if IsSandbox(s) { + uid, err = podUID(s) + if err != nil { + // Skip if we can't get pod UID, because this doesn't work + // for containerd 1.1. + return nil + } + } + var updated bool + for k, v := range s.Annotations { + if !isVolumeKey(k) { + continue + } + if volumeFieldName(k) != "type" { + continue + } + volume := volumeName(k) + if uid != "" { + // This is a sandbox. + path, err := volumePath(volume, uid) + if err != nil { + return fmt.Errorf("get volume path for %q: %w", volume, err) + } + s.Annotations[volumeSourceKey(volume)] = path + updated = true + } else { + // This is a container. + for i := range s.Mounts { + // An error is returned for sandbox if source + // annotation is not successfully applied, so + // it is guaranteed that the source annotation + // for sandbox has already been successfully + // applied at this point. + // + // The volume name is unique inside a pod, so + // matching without podUID is fine here. + // + // TODO: Pass podUID down to shim for containers to do + // more accurate matching. + if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { + // gVisor requires the container mount type to match + // sandbox mount type. + s.Mounts[i].Type = v + updated = true + } + } + } + } + if !updated { + return nil + } + // Update bundle. + b, err := json.Marshal(s) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) +} diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go new file mode 100644 index 000000000..3e02c6151 --- /dev/null +++ b/pkg/shim/v1/utils/volumes_test.go @@ -0,0 +1,308 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +func TestUpdateVolumeAnnotations(t *testing.T) { + dir, err := ioutil.TempDir("", "test-update-volume-annotations") + if err != nil { + t.Fatalf("create tempdir: %v", err) + } + defer os.RemoveAll(dir) + kubeletPodsDir = dir + + const ( + testPodUID = "testuid" + testVolumeName = "testvolume" + testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID + testLegacyLogDirPath = "/var/log/pods/" + testPodUID + ) + testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName) + + if err := os.MkdirAll(testVolumePath, 0755); err != nil { + t.Fatalf("Create test volume: %v", err) + } + + for _, test := range []struct { + desc string + spec *specs.Spec + expected *specs.Spec + expectErr bool + expectUpdate bool + }{ + { + desc: "volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "volume annotations for sandbox with legacy log path", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "tmpfs: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "bind: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "should not return error without pod log directory", + spec: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + }, + { + desc: "should return error if volume path does not exist", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount.notexist.share": "pod", + "dev.gvisor.spec.mount.notexist.type": "tmpfs", + "dev.gvisor.spec.mount.notexist.options": "ro", + }, + }, + expectErr: true, + }, + { + desc: "no volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + }, + { + desc: "no volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + bundle, err := ioutil.TempDir(dir, "test-bundle") + if err != nil { + t.Fatalf("Create test bundle: %v", err) + } + err = UpdateVolumeAnnotations(bundle, test.spec) + if test.expectErr { + if err == nil { + t.Fatal("Expected error, but got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(test.expected, test.spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) + } + if test.expectUpdate { + b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json")) + if err != nil { + t.Fatalf("Read spec from bundle: %v", err) + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + t.Fatalf("Unmarshal spec: %v", err) + } + if !reflect.DeepEqual(test.expected, &spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, &spec) + } + } + }) + } +} diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD new file mode 100644 index 000000000..7e0a114a0 --- /dev/null +++ b/pkg/shim/v2/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "v2", + srcs = [ + "api.go", + "epoll.go", + "service.go", + "service_linux.go", + ], + visibility = ["//shim:__subpackages__"], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "//pkg/shim/v2/options", + "//pkg/shim/v2/runtimeoptions", + "//runsc/specutils", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + "@com_github_containerd_containerd//runtime/v2/task:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go new file mode 100644 index 000000000..dbe5c59f6 --- /dev/null +++ b/pkg/shim/v2/api.go @@ -0,0 +1,22 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskOOM = events.TaskOOM diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go new file mode 100644 index 000000000..41232cca8 --- /dev/null +++ b/pkg/shim/v2/epoll.go @@ -0,0 +1,129 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "sync" + + "github.com/containerd/cgroups" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/runtime" + "golang.org/x/sys/unix" +) + +func newOOMEpoller(publisher events.Publisher) (*epoller, error) { + fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + return &epoller{ + fd: fd, + publisher: publisher, + set: make(map[uintptr]*item), + }, nil +} + +type epoller struct { + mu sync.Mutex + + fd int + publisher events.Publisher + set map[uintptr]*item +} + +type item struct { + id string + cg cgroups.Cgroup +} + +func (e *epoller) Close() error { + return unix.Close(e.fd) +} + +func (e *epoller) run(ctx context.Context) { + var events [128]unix.EpollEvent + for { + select { + case <-ctx.Done(): + e.Close() + return + default: + n, err := unix.EpollWait(e.fd, events[:], -1) + if err != nil { + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + // Should not happen. + panic(fmt.Errorf("cgroups: epoll wait: %w", err)) + } + for i := 0; i < n; i++ { + e.process(ctx, uintptr(events[i].Fd)) + } + } + } +} + +func (e *epoller) add(id string, cg cgroups.Cgroup) error { + e.mu.Lock() + defer e.mu.Unlock() + fd, err := cg.OOMEventFD() + if err != nil { + return err + } + e.set[fd] = &item{ + id: id, + cg: cg, + } + event := unix.EpollEvent{ + Fd: int32(fd), + Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR, + } + return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event) +} + +func (e *epoller) process(ctx context.Context, fd uintptr) { + flush(fd) + e.mu.Lock() + i, ok := e.set[fd] + if !ok { + e.mu.Unlock() + return + } + e.mu.Unlock() + if i.cg.State() == cgroups.Deleted { + e.mu.Lock() + delete(e.set, fd) + e.mu.Unlock() + unix.Close(int(fd)) + return + } + if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{ + ContainerID: i.id, + }); err != nil { + // Should not happen. + panic(fmt.Errorf("publish OOM event: %w", err)) + } +} + +func flush(fd uintptr) error { + var buf [8]byte + _, err := unix.Read(int(fd), buf[:]) + return err +} diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD new file mode 100644 index 000000000..ca212e874 --- /dev/null +++ b/pkg/shim/v2/options/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "options", + srcs = [ + "options.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go new file mode 100644 index 000000000..de09f2f79 --- /dev/null +++ b/pkg/shim/v2/options/options.go @@ -0,0 +1,33 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +const OptionType = "io.containerd.runsc.v1.options" + +// Options is runtime options for io.containerd.runsc.v1. +type Options struct { + // ShimCgroup is the cgroup the shim should be in. + ShimCgroup string `toml:"shim_cgroup"` + // IoUid is the I/O's pipes uid. + IoUid uint32 `toml:"io_uid"` + // IoUid is the I/O's pipes gid. + IoGid uint32 `toml:"io_gid"` + // BinaryName is the binary name of the runsc binary. + BinaryName string `toml:"binary_name"` + // Root is the runsc root directory. + Root string `toml:"root"` + // RunscConfig is a key/value map of all runsc flags. + RunscConfig map[string]string `toml:"runsc_config"` +} diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD new file mode 100644 index 000000000..ba2ed1ea7 --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") + +package(licenses = ["notice"]) + +proto_library( + name = "api", + srcs = [ + "runtimeoptions.proto", + ], +) + +go_library( + name = "runtimeoptions", + srcs = ["runtimeoptions.go"], + visibility = ["//pkg/shim/v2:__pkg__"], + deps = [ + ":api_go_proto", + "@com_github_gogo_protobuf//proto:go_default_library", + ], +) + +go_test( + name = "runtimeoptions_test", + size = "small", + srcs = ["runtimeoptions_test.go"], + library = ":runtimeoptions", + deps = [ + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_golang_protobuf//proto:go_default_library", + ], +) diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go new file mode 100644 index 000000000..aaf17b87a --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go @@ -0,0 +1,30 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + proto "github.com/gogo/protobuf/proto" + pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto" +) + +type Options = pb.Options + +func init() { + // The generated proto file auto registers with "golang/protobuf/proto" + // package. However, typeurl uses "golang/gogo/protobuf/proto". So registers + // the type there too. + proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") +} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto new file mode 100644 index 000000000..057032e34 --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto @@ -0,0 +1,25 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package cri.runtimeoptions.v1; + +// This is a version of the runtimeoptions CRI API that is vendored. +// +// Importing the full CRI package is a nightmare. +message Options { + string type_url = 1; + string config_path = 2; +} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go new file mode 100644 index 000000000..f4c238a00 --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + "testing" + + shim "github.com/containerd/containerd/runtime/v1/shim/v1" + "github.com/containerd/typeurl" + "github.com/golang/protobuf/proto" +) + +func TestCreateTaskRequest(t *testing.T) { + // Serialize the top-level message. + const encodedText = `options: < + type_url: "cri.runtimeoptions.v1.Options" + value: "\n\010type_url\022\013config_path" +>` + got := &shim.CreateTaskRequest{} // Should have raw options. + if err := proto.UnmarshalText(encodedText, got); err != nil { + t.Fatalf("unable to unmarshal text: %v", err) + } + t.Logf("got: %s", proto.MarshalTextString(got)) + + // Check the options. + wantOptions := &Options{} + wantOptions.TypeUrl = "type_url" + wantOptions.ConfigPath = "config_path" + gotMessage, err := typeurl.UnmarshalAny(got.Options) + if err != nil { + t.Fatalf("unable to unmarshal any: %v", err) + } + gotOptions, ok := gotMessage.(*Options) + if !ok { + t.Fatalf("got %v, want %v", gotMessage, wantOptions) + } + if !proto.Equal(gotOptions, wantOptions) { + t.Fatalf("got %v, want %v", gotOptions, wantOptions) + } +} diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go new file mode 100644 index 000000000..1534152fc --- /dev/null +++ b/pkg/shim/v2/service.go @@ -0,0 +1,824 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/BurntSushi/toml" + "github.com/containerd/cgroups" + "github.com/containerd/console" + "github.com/containerd/containerd/api/events" + "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/containerd/runtime" + "github.com/containerd/containerd/runtime/linux/runctypes" + "github.com/containerd/containerd/runtime/v2/shim" + taskAPI "github.com/containerd/containerd/runtime/v2/task" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/proc" + "gvisor.dev/gvisor/pkg/shim/v1/utils" + "gvisor.dev/gvisor/pkg/shim/v2/options" + "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions" + "gvisor.dev/gvisor/runsc/specutils" +) + +var ( + empty = &types.Empty{} + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, + } +) + +var _ = (taskAPI.TaskService)(&service{}) + +// configFile is the default config file name. For containerd 1.2, +// we assume that a config.toml should exist in the runtime root. +const configFile = "config.toml" + +// New returns a new shim service that can be used via GRPC. +func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { + ep, err := newOOMEpoller(publisher) + if err != nil { + return nil, err + } + go ep.run(ctx) + s := &service{ + id: id, + context: ctx, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + oomPoller: ep, + cancel: cancel, + } + go s.processExits() + runsc.Monitor = reaper.Default + if err := s.initPlatform(); err != nil { + cancel() + return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) + } + go s.forward(publisher) + return s, nil +} + +// service is the shim implementation of a remote shim over GRPC. +type service struct { + mu sync.Mutex + + context context.Context + task process.Process + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + opts options.Options + ec chan proc.Exit + oomPoller *epoller + + id string + bundle string + cancel func() +} + +func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + self, err := os.Executable() + if err != nil { + return nil, err + } + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + args := []string{ + "-namespace", ns, + "-address", containerdAddress, + "-publish-binary", containerdBinary, + } + cmd := exec.Command(self, args...) + cmd.Dir = cwd + cmd.Env = append(os.Environ(), "GOMAXPROCS=2") + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + return cmd, nil +} + +func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) { + cmd, err := newCommand(ctx, containerdBinary, containerdAddress) + if err != nil { + return "", err + } + address, err := shim.SocketAddress(ctx, id) + if err != nil { + return "", err + } + socket, err := shim.NewSocket(address) + if err != nil { + return "", err + } + defer socket.Close() + f, err := socket.File() + if err != nil { + return "", err + } + defer f.Close() + + cmd.ExtraFiles = append(cmd.ExtraFiles, f) + + if err := cmd.Start(); err != nil { + return "", err + } + defer func() { + if err != nil { + cmd.Process.Kill() + } + }() + // make sure to wait after start + go cmd.Wait() + if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil { + return "", err + } + if err := shim.WriteAddress("address", address); err != nil { + return "", err + } + if err := shim.SetScore(cmd.Process.Pid); err != nil { + return "", fmt.Errorf("failed to set OOM Score on shim: %w", err) + } + return address, nil +} + +func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{ + Force: true, + }); err != nil { + log.L.Printf("failed to remove runc container: %v", err) + } + if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + return &taskAPI.DeleteResponse{ + ExitedAt: time.Now(), + ExitStatus: 128 + uint32(unix.SIGKILL), + }, nil +} + +func (s *service) readRuntime(path string) (string, error) { + data, err := ioutil.ReadFile(filepath.Join(path, "runtime")) + if err != nil { + return "", err + } + return string(data), nil +} + +func (s *service) writeRuntime(path, runtime string) error { + return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600) +} + +// Create creates a new initial process and container with the underlying OCI +// runtime. +func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, fmt.Errorf("create namespace: %w", err) + } + + // Read from root for now. + var opts options.Options + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + var path string + switch o := v.(type) { + case *runctypes.CreateOptions: // containerd 1.2.x + opts.IoUid = o.IoUid + opts.IoGid = o.IoGid + opts.ShimCgroup = o.ShimCgroup + case *runctypes.RuncOptions: // containerd 1.2.x + root := proc.RunscRoot + if o.RuntimeRoot != "" { + root = o.RuntimeRoot + } + + opts.BinaryName = o.Runtime + + path = filepath.Join(root, configFile) + if _, err := os.Stat(path); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("stat config file %q: %w", path, err) + } + // A config file in runtime root is not required. + path = "" + } + case *runtimeoptions.Options: // containerd 1.3.x+ + if o.ConfigPath == "" { + break + } + if o.TypeUrl != options.OptionType { + return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl) + } + path = o.ConfigPath + default: + return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl) + } + if path != "" { + if _, err = toml.DecodeFile(path, &opts); err != nil { + return nil, fmt.Errorf("decode config file %q: %w", path, err) + } + } + } + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: opts.BinaryName, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil { + return nil, err + } + defer func() { + if err != nil { + if err := mount.UnmountAll(rootfs, 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + r.Bundle, + filepath.Join(r.Bundle, "work"), + ns, + s.platform, + config, + &opts, + rootfs, + ) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + + // Set up OOM notification on the sandbox's cgroup. This is done on + // sandbox create since the sandbox process will be created here. + pid := process.Pid() + if pid > 0 { + cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid)) + if err != nil { + return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err) + } + if err := s.oomPoller.add(s.id, cg); err != nil { + return nil, fmt.Errorf("add cg to OOM monitor: %w", err) + } + } + s.task = process + s.opts = opts + return &taskAPI.CreateTaskResponse{ + Pid: uint32(process.Pid()), + }, nil + +} + +// Start starts a process. +func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + // TODO: Set the cgroup and oom notifications on restore. + // https://github.com/google/gvisor-containerd-shim/issues/58 + return &taskAPI.StartResponse{ + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + isTask := r.ExecID == "" + if !isTask { + s.mu.Lock() + delete(s.processes, r.ExecID) + s.mu.Unlock() + } + if isTask && s.platform != nil { + s.platform.Close() + } + return &taskAPI.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + p := s.processes[r.ExecID] + s.mu.Unlock() + if p != nil { + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) + } + p = s.task + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{ + ID: r.ExecID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ExecID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resizes the terminal of a process. +func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &taskAPI.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause the container. +func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume the container. +func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill a process with the provided signal. +func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// Pids returns all pids inside the container. +func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &taskAPI.PidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Connect returns shim information such as the shim's pid. +func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) { + var pid int + if s.task != nil { + pid = s.task.Pid() + } + return &taskAPI.ConnectResponse{ + ShimPid: uint32(os.Getpid()), + TaskPid: uint32(pid), + }, nil +} + +func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) { + s.cancel() + os.Exit(0) + return empty, nil +} + +func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + stats, err := rs.Stats(ctx, s.id) + if err != nil { + return nil, err + } + + // gvisor currently (as of 2020-03-03) only returns the total memory + // usage and current PID value[0]. However, we copy the common fields here + // so that future updates will propagate correct information. We're + // using the cgroups.Metrics structure so we're returning the same type + // as runc. + // + // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 + data, err := typeurl.MarshalAny(&cgroups.Metrics{ + CPU: &cgroups.CPUStat{ + Usage: &cgroups.CPUUsage{ + Total: stats.Cpu.Usage.Total, + Kernel: stats.Cpu.Usage.Kernel, + User: stats.Cpu.Usage.User, + PerCPU: stats.Cpu.Usage.Percpu, + }, + Throttling: &cgroups.Throttle{ + Periods: stats.Cpu.Throttling.Periods, + ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, + ThrottledTime: stats.Cpu.Throttling.ThrottledTime, + }, + }, + Memory: &cgroups.MemoryStat{ + Cache: stats.Memory.Cache, + Usage: &cgroups.MemoryEntry{ + Limit: stats.Memory.Usage.Limit, + Usage: stats.Memory.Usage.Usage, + Max: stats.Memory.Usage.Max, + Failcnt: stats.Memory.Usage.Failcnt, + }, + Swap: &cgroups.MemoryEntry{ + Limit: stats.Memory.Swap.Limit, + Usage: stats.Memory.Swap.Usage, + Max: stats.Memory.Swap.Max, + Failcnt: stats.Memory.Swap.Failcnt, + }, + Kernel: &cgroups.MemoryEntry{ + Limit: stats.Memory.Kernel.Limit, + Usage: stats.Memory.Kernel.Usage, + Max: stats.Memory.Kernel.Max, + Failcnt: stats.Memory.Kernel.Failcnt, + }, + KernelTCP: &cgroups.MemoryEntry{ + Limit: stats.Memory.KernelTCP.Limit, + Usage: stats.Memory.KernelTCP.Usage, + Max: stats.Memory.KernelTCP.Max, + Failcnt: stats.Memory.KernelTCP.Failcnt, + }, + }, + Pids: &cgroups.PidsStat{ + Current: stats.Pids.Current, + Limit: stats.Pids.Limit, + }, + }) + if err != nil { + return nil, err + } + return &taskAPI.StatsResponse{ + Stats: data, + }, nil +} + +// Update updates a running container. +func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + p.Wait() + + return &taskAPI.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *service) checkProcesses(e proc.Exit) { + // TODO(random-liu): Add `shouldKillAll` logic if container pid + // namespace is supported. + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &events.TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *service) allProcesses() (o []process.Process) { + s.mu.Lock() + defer s.mu.Unlock() + for _, p := range s.processes { + o = append(o, p) + } + if s.task != nil { + o = append(o, s.task) + } + return o +} + +func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + s.mu.Lock() + p := s.task + s.mu.Unlock() + if p == nil { + return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) + } + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *service) forward(publisher shim.Publisher) { + for e := range s.events { + ctx, cancel := context.WithTimeout(s.context, 5*time.Second) + err := publisher.Publish(ctx, getTopic(e), e) + cancel() + if err != nil { + // Should not happen. + panic(fmt.Errorf("post event: %w", err)) + } + } +} + +func (s *service) getProcess(execID string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + if execID == "" { + return s.task, nil + } + p := s.processes[execID] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) + } + return p, nil +} + +func getTopic(e interface{}) string { + switch e.(type) { + case *events.TaskCreate: + return runtime.TaskCreateEventTopic + case *events.TaskStart: + return runtime.TaskStartEventTopic + case *events.TaskOOM: + return runtime.TaskOOMEventTopic + case *events.TaskExit: + return runtime.TaskExitEventTopic + case *events.TaskDelete: + return runtime.TaskDeleteEventTopic + case *events.TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *events.TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) { + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + runsc.FormatLogPath(r.ID, options.RunscConfig) + runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go new file mode 100644 index 000000000..1800ab90b --- /dev/null +++ b/pkg/shim/v2/service_linux.go @@ -0,0 +1,108 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index e131455f7..ae0fe1522 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -12,6 +12,7 @@ go_library( "sleep_unsafe.go", ], visibility = ["//:sandbox"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index af47e2ba1..1dd11707d 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -379,10 +379,7 @@ func TestRace(t *testing.T) { // TestRaceInOrder tests that multiple wakers can continuously send wake requests to // the sleeper and that the wakers are retrieved in the order asserted. func TestRaceInOrder(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - w := make([]Waker, wakers) + w := make([]Waker, 10000) s := Sleeper{} // Associate each waker and start goroutines that will assert them. @@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) { s.AddWaker(&w[i], i) } go func() { - n := 0 - for n < wakeRequests { - wk := w[n%len(w)] - wk.Assert() - n++ + for i := range w { + w[i].Assert() } }() // Wait for all wake up notifications from all wakers. - for i := 0; i < wakeRequests; i++ { - v, _ := s.Fetch(true) - if got, want := v, i%wakers; got != want { - t.Fatalf("got %d want %d", got, want) + for want := range w { + got, _ := s.Fetch(true) + if got != want { + t.Fatalf("got %d want %d", got, want) } } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index f68c12620..19bce2afb 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. @@ -75,6 +75,8 @@ package sleep import ( "sync/atomic" "unsafe" + + "gvisor.dev/gvisor/pkg/sync" ) const ( @@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { // // This struct is thread-safe, that is, its methods can be called concurrently // by multiple goroutines. +// +// Note, it is not safe to copy a Waker as its fields are modified by value +// (the pointer fields are individually modified with atomic operations). type Waker struct { + _ sync.NoCopy + // s is the sleeper that this waker can wake up. Only one sleeper at a // time is allowed. This field can have three classes of values: // nil -- the waker is not asserted: it either is not associated with diff --git a/pkg/state/decode.go b/pkg/state/decode.go index c9971cdf6..89467ca8e 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -584,10 +584,12 @@ func (ds *decodeState) Load(obj reflect.Value) { }) // Create the root object. - ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ + rootOds := &objectDecodeState{ id: 1, obj: obj, - }) + } + ds.objectsByID = append(ds.objectsByID, rootOds) + ds.pending.PushBack(rootOds) // Read the number of objects. lastID, object, err := ReadHeader(ds.r) diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index cf37aaa49..887f453a9 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -26,12 +26,17 @@ import ( "gvisor.dev/gvisor/pkg/state/wire" ) -func formatRef(x *wire.Ref, graph uint64, html bool) string { +type printer struct { + html bool + typeSpecs map[string]*wire.Type +} + +func (p *printer) formatRef(x *wire.Ref, graph uint64) string { baseRef := fmt.Sprintf("g%dr%d", graph, x.Root) fullRef := baseRef if len(x.Dots) > 0 { // See wire.Ref; Type valid if Dots non-zero. - typ, _ := formatType(x.Type, graph, html) + typ, _ := p.formatType(x.Type, graph) var buf strings.Builder buf.WriteString("(*") buf.WriteString(typ) @@ -51,34 +56,40 @@ func formatRef(x *wire.Ref, graph uint64, html bool) string { buf.WriteString(")") fullRef = buf.String() } - if html { + if p.html { return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef) } return fullRef } -func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { +func (p *printer) formatType(t wire.TypeSpec, graph uint64) (string, bool) { switch x := t.(type) { case wire.TypeID: - base := fmt.Sprintf("g%dt%d", graph, x) - if html { - return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true + tag := fmt.Sprintf("g%dt%d", graph, x) + desc := tag + if spec, ok := p.typeSpecs[tag]; ok { + desc += fmt.Sprintf("=%s", spec.Name) + } else { + desc += "!missing-type-spec" + } + if p.html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", tag, desc), true } - return fmt.Sprintf("%s", base), true + return desc, true case wire.TypeSpecNil: return "", false // Only nil type. case *wire.TypeSpecPointer: - element, _ := formatType(x.Type, graph, html) + element, _ := p.formatType(x.Type, graph) return fmt.Sprintf("(*%s)", element), true case *wire.TypeSpecArray: - element, _ := formatType(x.Type, graph, html) + element, _ := p.formatType(x.Type, graph) return fmt.Sprintf("[%d](%s)", x.Count, element), true case *wire.TypeSpecSlice: - element, _ := formatType(x.Type, graph, html) + element, _ := p.formatType(x.Type, graph) return fmt.Sprintf("([]%s)", element), true case *wire.TypeSpecMap: - key, _ := formatType(x.Key, graph, html) - value, _ := formatType(x.Value, graph, html) + key, _ := p.formatType(x.Key, graph) + value, _ := p.formatType(x.Value, graph) return fmt.Sprintf("(map[%s]%s)", key, value), true default: panic(fmt.Sprintf("unreachable: unknown type %T", t)) @@ -87,7 +98,7 @@ func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { // format formats a single object, for pretty-printing. It also returns whether // the value is a non-zero value. -func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) { +func (p *printer) format(graph uint64, depth int, encoded wire.Object) (string, bool) { switch x := encoded.(type) { case wire.Nil: return "nil", false @@ -98,7 +109,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo case *wire.Complex128: return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 case *wire.Ref: - return formatRef(x, graph, html), x.Root != 0 + return p.formatRef(x, graph), x.Root != 0 case *wire.Type: tabs := "\n" + strings.Repeat("\t", depth) items := make([]string, 0, len(x.Fields)+2) @@ -109,7 +120,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "}") return strings.Join(items, tabs), true // No zero value. case *wire.Slice: - return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0 + return fmt.Sprintf("%s{len:%d,cap:%d}", p.formatRef(&x.Ref, graph), x.Length, x.Capacity), x.Capacity != 0 case *wire.Array: if len(x.Contents) == 0 { return "[]", false @@ -119,7 +130,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "[") tabs := "\n" + strings.Repeat("\t", depth) for i := 0; i < len(x.Contents); i++ { - item, ok := format(graph, depth+1, x.Contents[i], html) + item, ok := p.format(graph, depth+1, x.Contents[i]) if !ok { zeros = append(zeros, fmt.Sprintf("\t%s,", item)) continue @@ -136,7 +147,9 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "]") return strings.Join(items, tabs), len(zeros) < len(x.Contents) case *wire.Struct: - typ, _ := formatType(x.TypeID, graph, html) + tag := fmt.Sprintf("g%dt%d", graph, x.TypeID) + spec, _ := p.typeSpecs[tag] + typ, _ := p.formatType(x.TypeID, graph) if x.Fields() == 0 { return fmt.Sprintf("struct[%s]{}", typ), false } @@ -145,10 +158,15 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo tabs := "\n" + strings.Repeat("\t", depth) allZero := true for i := 0; i < x.Fields(); i++ { - element, ok := format(graph, depth+1, *x.Field(i), html) + var name string + if spec != nil && i < len(spec.Fields) { + name = spec.Fields[i] + } else { + name = fmt.Sprintf("%d", i) + } + element, ok := p.format(graph, depth+1, *x.Field(i)) allZero = allZero && !ok - items = append(items, fmt.Sprintf("\t%d: %s,", i, element)) - i++ + items = append(items, fmt.Sprintf("\t%s: %s,", name, element)) } items = append(items, "}") return strings.Join(items, tabs), !allZero @@ -160,15 +178,15 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "map{") tabs := "\n" + strings.Repeat("\t", depth) for i := 0; i < len(x.Keys); i++ { - key, _ := format(graph, depth+1, x.Keys[i], html) - value, _ := format(graph, depth+1, x.Values[i], html) + key, _ := p.format(graph, depth+1, x.Keys[i]) + value, _ := p.format(graph, depth+1, x.Values[i]) items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) } items = append(items, "}") return strings.Join(items, tabs), true case *wire.Interface: - typ, typOk := formatType(x.Type, graph, html) - element, elementOk := format(graph, depth+1, x.Value, html) + typ, typOk := p.formatType(x.Type, graph) + element, elementOk := p.format(graph, depth+1, x.Value) return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk default: // Must be a primitive; use reflection. @@ -177,11 +195,11 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo } // printStream is the basic print implementation. -func printStream(w io.Writer, r wire.Reader, html bool) (err error) { +func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // current graph ID. var graph uint64 - if html { + if p.html { fmt.Fprintf(w, "<pre>") defer fmt.Fprintf(w, "</pre>") } @@ -196,6 +214,8 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { } }() + p.typeSpecs = make(map[string]*wire.Type) + for { // Find the first object to begin generation. length, object, err := state.ReadHeader(r) @@ -223,18 +243,19 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. var ( - oid uint64 = 1 - tid uint64 = 1 + tid uint64 = 1 + objects []wire.Object ) - for oid <= length { + for oid := uint64(1); oid <= length; { // Unmarshal the object. encoded := wire.Load(r) // Is this a type? - if _, ok := encoded.(*wire.Type); ok { - str, _ := format(graph, 0, encoded, html) + if typ, ok := encoded.(*wire.Type); ok { + str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - if html { + p.typeSpecs[tag] = typ + if p.html { // See below. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) } @@ -245,17 +266,24 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { continue } + // Otherwise, it is a node. + objects = append(objects, encoded) + oid++ + } + + for i, encoded := range objects { + // oid starts at 1. + oid := i + 1 // Format the node. - str, _ := format(graph, 0, encoded, html) + str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dr%d", graph, oid) - if html { + if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) } if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { return err } - oid++ } } @@ -264,10 +292,10 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { // PrintText reads the stream from r and prints text to w. func PrintText(w io.Writer, r wire.Reader) error { - return printStream(w, r, false /* html */) + return (&printer{}).printStream(w, r) } // PrintHTML reads the stream from r and prints html to w. func PrintHTML(w io.Writer, r wire.Reader) error { - return printStream(w, r, true /* html */) + return (&printer{html: true}).printStream(w, r) } diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go index 1e9794296..3c73ac391 100644 --- a/pkg/state/tests/load_test.go +++ b/pkg/state/tests/load_test.go @@ -20,6 +20,14 @@ import ( func TestLoadHooks(t *testing.T) { runTestCases(t, false, "load-hooks", []interface{}{ + // Root object being a struct. + afterLoadStruct{v: 1}, + valueLoadStruct{v: 1}, + genericContainer{v: &afterLoadStruct{v: 1}}, + genericContainer{v: &valueLoadStruct{v: 1}}, + sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, + sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, + // Root object being a pointer. &afterLoadStruct{v: 1}, &valueLoadStruct{v: 1}, &genericContainer{v: &afterLoadStruct{v: 1}}, diff --git a/pkg/state/types.go b/pkg/state/types.go index 215ef80f8..84aed8732 100644 --- a/pkg/state/types.go +++ b/pkg/state/types.go @@ -107,6 +107,14 @@ func lookupNameFields(typ reflect.Type) (string, []string, bool) { } return name, nil, true } + // Sanity check the type. + if raceEnabled { + if _, ok := reverseTypeDatabase[typ]; !ok { + // The type was not registered? Must be an embedded + // structure or something else. + return "", nil, false + } + } // Extract the name from the object. name := t.StateTypeName() fields := t.StateFields() @@ -313,6 +321,9 @@ var primitiveTypeDatabase = func() map[string]reflect.Type { // globalTypeDatabase is used for dispatching interfaces on decode. var globalTypeDatabase = map[string]reflect.Type{} +// reverseTypeDatabase is a reverse mapping. +var reverseTypeDatabase = map[reflect.Type]string{} + // Register registers a type. // // This must be called on init and only done once. @@ -358,4 +369,7 @@ func Register(t Type) { Failf("conflicting name for %T: matches interfaceType", t) } globalTypeDatabase[name] = typ + if raceEnabled { + reverseTypeDatabase[typ] = name + } } diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index d0d77e19c..68535c3b1 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -33,10 +33,12 @@ go_library( "aliases.go", "memmove_unsafe.go", "mutex_unsafe.go", + "nocopy.go", "norace_unsafe.go", "race_unsafe.go", "rwmutex_unsafe.go", "seqcount.go", + "spin_unsafe.go", "sync.go", ], marshal = False, diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go index 1d7780695..f5e630009 100644 --- a/pkg/sync/memmove_unsafe.go +++ b/pkg/sync/memmove_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go index dc034d561..f4c2e9642 100644 --- a/pkg/sync/mutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.16 +// +build !go1.17 // When updating the build constraint (above), check that syncMutex matches the // standard library sync.Mutex definition. diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go new file mode 100644 index 000000000..722b29501 --- /dev/null +++ b/pkg/sync/nocopy.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +// NoCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type NoCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Unlock() {} diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index 995c0346e..b3b4dee78 100644 --- a/pkg/sync/rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go index eda6fb131..2184cb5ab 100644 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sync/seqatomic_unsafe.go @@ -25,41 +25,35 @@ import ( type Value struct{} // SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value +// with any writer critical sections in seq. +// +//go:nosplit +func SeqAtomicLoad(seq *sync.SeqCount, ptr *Value) Value { for { - epoch := sc.BeginRead() - if sync.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break + if val, ok := SeqAtomicTryLoad(seq, seq.BeginRead(), ptr); ok { + return val } } - return val } // SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns +// in seq initiated by a call to seq.BeginRead() that returned epoch. If the +// read would race with a writer critical section, SeqAtomicTryLoad returns // (unspecified, false). -func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value +// +//go:nosplit +func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (val Value, ok bool) { if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, so it + // can't help us here. Instead, call runtime.memmove directly, which is + // not instrumented by the race detector. sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { + // This is ~40% faster for short reads than going through memmove. val = *ptr } - return val, sc.ReadOk(epoch) + ok = seq.ReadOk(epoch) + return } func init() { diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go index a1e895352..2c5d3df99 100644 --- a/pkg/sync/seqcount.go +++ b/pkg/sync/seqcount.go @@ -8,7 +8,6 @@ package sync import ( "fmt" "reflect" - "runtime" "sync/atomic" ) @@ -43,9 +42,7 @@ type SeqCount struct { } // SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} +type SeqCountEpoch uint32 // We assume that: // @@ -83,12 +80,25 @@ type SeqCountEpoch struct { // using this pattern. Most users of SeqCount will need to use the // SeqAtomicLoad function template in seqatomic.go. func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) + if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 { + return SeqCountEpoch(epoch) + } + return s.beginReadSlow() +} + +func (s *SeqCount) beginReadSlow() SeqCountEpoch { + i := 0 + for { + if canSpin(i) { + i++ + doSpin() + } else { + goyield() + } + if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 { + return SeqCountEpoch(epoch) + } } - return SeqCountEpoch{epoch} } // ReadOk returns true if the reader critical section initiated by a previous @@ -99,7 +109,7 @@ func (s *SeqCount) BeginRead() SeqCountEpoch { // Reader critical sections do not need to be explicitly terminated; the last // call to ReadOk is implicitly the end of the reader critical section. func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val + return atomic.LoadUint32(&s.epoch) == uint32(epoch) } // BeginWrite indicates the beginning of a writer critical section. diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/spin_unsafe.go new file mode 100644 index 000000000..cafb2d065 --- /dev/null +++ b/pkg/sync/spin_unsafe.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.17 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + _ "unsafe" // for go:linkname +) + +//go:linkname canSpin sync.runtime_canSpin +func canSpin(i int) bool + +//go:linkname doSpin sync.runtime_doSpin +func doSpin() + +//go:linkname goyield runtime.goyield +func goyield() diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go index 4bff59e7d..dabf08895 100644 --- a/pkg/syncevent/broadcaster.go +++ b/pkg/syncevent/broadcaster.go @@ -111,7 +111,9 @@ func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID { return id } -// Preconditions: table must not be full. len(table) is a power of 2. +// Preconditions: +// * table must not be full. +// * len(table) is a power of 2. func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) { entry := broadcasterSlot{ receiver: r, diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go index ddffb171a..d3d0f34c5 100644 --- a/pkg/syncevent/source.go +++ b/pkg/syncevent/source.go @@ -19,9 +19,11 @@ type Source interface { // SubscribeEvents causes the Source to notify the given Receiver of the // given subset of events. // - // Preconditions: r != nil. The ReceiverCallback for r must not take locks - // that are ordered prior to the Source; for example, it cannot call any - // Source methods. + // Preconditions: + // * r != nil. + // * The ReceiverCallback for r must not take locks that are ordered + // prior to the Source; for example, it cannot call any Source + // methods. SubscribeEvents(r *Receiver, filter Set) SubscriptionID // UnsubscribeEvents causes the Source to stop notifying the Receiver diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s index 985b56ae5..5e216b045 100644 --- a/pkg/syncevent/waiter_amd64.s +++ b/pkg/syncevent/waiter_amd64.s @@ -16,9 +16,9 @@ // See waiter_noasm_unsafe.go for a description of waiterUnlock. // -// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool +// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 - MOVQ g+0(FP), DI + MOVQ ptr+0(FP), DI MOVQ wg+8(FP), SI MOVQ $·preparingG(SB), AX diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s index 20d7ac23b..f4c06f194 100644 --- a/pkg/syncevent/waiter_arm64.s +++ b/pkg/syncevent/waiter_arm64.s @@ -16,11 +16,11 @@ // See waiter_noasm_unsafe.go for a description of waiterUnlock. // -// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool +// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 MOVD wg+8(FP), R0 MOVD $·preparingG(SB), R1 - MOVD g+0(FP), R2 + MOVD ptr+0(FP), R2 again: LDAXR (R0), R3 CMP R1, R3 diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go index 0995e9053..19d6b0b15 100644 --- a/pkg/syncevent/waiter_asm_unsafe.go +++ b/pkg/syncevent/waiter_asm_unsafe.go @@ -21,4 +21,4 @@ import ( ) // See waiter_noasm_unsafe.go for a description of waiterUnlock. -func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool +func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go index 1c4b0e39a..0f74a689c 100644 --- a/pkg/syncevent/waiter_noasm_unsafe.go +++ b/pkg/syncevent/waiter_noasm_unsafe.go @@ -32,8 +32,8 @@ import ( // should be aborted. // //go:nosplit -func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool { +func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool { // The only way this CAS can fail is if a call to Waiter.NotifyPending() // has replaced *wg with nil, in which case we should not sleep. - return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g) + return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), ptr) } diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go index ad271e1a0..518f18479 100644 --- a/pkg/syncevent/waiter_unsafe.go +++ b/pkg/syncevent/waiter_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 8ff922c69..5ae10939d 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -22,7 +22,7 @@ import ( // Mapping for tcpip.Error types. var ( ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL) + ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index c73072c42..f516c8e46 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -33,6 +33,7 @@ var ( EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) + ECONNABORTED = error(syscall.ECONNABORTED) ECONNREFUSED = error(syscall.ECONNREFUSED) ECONNRESET = error(syscall.ECONNRESET) EDEADLK = error(syscall.EDEADLK) @@ -61,6 +62,7 @@ var ( ENOMEM = error(syscall.ENOMEM) ENOSPC = error(syscall.ENOSPC) ENOSYS = error(syscall.ENOSYS) + ENOTCONN = error(syscall.ENOTCONN) ENOTDIR = error(syscall.ENOTDIR) ENOTEMPTY = error(syscall.ENOTEMPTY) ENOTSOCK = error(syscall.ENOTSOCK) @@ -152,6 +154,73 @@ func ConvertIntr(err, intr error) error { return err } +// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel +// include/linux/errno.h. These errnos are never returned to userspace +// directly, but are used to communicate the expected behavior of an +// interrupted syscall from the syscall to signal handling. +type SyscallRestartErrno int + +// These numeric values are significant because ptrace syscall exit tracing can +// observe them. +// +// For all of the following errnos, if the syscall is not interrupted by a +// signal delivered to a user handler, the syscall is restarted. +const ( + // ERESTARTSYS is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler without SA_RESTART set, and restarted otherwise. + ERESTARTSYS = SyscallRestartErrno(512) + + // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it + // should always be restarted. + ERESTARTNOINTR = SyscallRestartErrno(513) + + // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler, and restarted otherwise. + ERESTARTNOHAND = SyscallRestartErrno(514) + + // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate + // that it should be restarted using a custom function. The interrupted + // syscall must register a custom restart function by calling + // Task.SetRestartSyscallFn. + ERESTART_RESTARTBLOCK = SyscallRestartErrno(516) +) + +// Error implements error.Error. +func (e SyscallRestartErrno) Error() string { + // Descriptions are borrowed from strace. + switch e { + case ERESTARTSYS: + return "to be restarted if SA_RESTART is set" + case ERESTARTNOINTR: + return "to be restarted" + case ERESTARTNOHAND: + return "to be restarted if no handler" + case ERESTART_RESTARTBLOCK: + return "interrupted by signal" + default: + return "(unknown interrupt error)" + } +} + +// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by +// rv, the value in a syscall return register. +func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) { + switch int(rv) { + case -int(ERESTARTSYS): + return ERESTARTSYS, true + case -int(ERESTARTNOINTR): + return ERESTARTNOINTR, true + case -int(ERESTARTNOHAND): + return ERESTARTNOHAND, true + case -int(ERESTART_RESTARTBLOCK): + return ERESTART_RESTARTBLOCK, true + default: + return 0, false + } +} + func init() { AddErrorTranslation(ErrWouldBlock, syscall.EWOULDBLOCK) AddErrorTranslation(ErrInterrupted, syscall.EINTR) diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go index 29719752e..7036467c4 100644 --- a/pkg/syserror/syserror_test.go +++ b/pkg/syserror/syserror_test.go @@ -24,27 +24,20 @@ import ( var globalError error -func returnErrnoAsError() error { - return syscall.EINVAL -} - -func returnError() error { - return syserror.EINVAL -} - -func BenchmarkReturnErrnoAsError(b *testing.B) { +func BenchmarkAssignErrno(b *testing.B) { for i := b.N; i > 0; i-- { - returnErrnoAsError() + globalError = syscall.EINVAL } } -func BenchmarkReturnError(b *testing.B) { +func BenchmarkAssignError(b *testing.B) { for i := b.N; i > 0; i-- { - returnError() + globalError = syserror.EINVAL } } func BenchmarkCompareErrno(b *testing.B) { + globalError = syscall.EAGAIN j := 0 for i := b.N; i > 0; i-- { if globalError == syscall.EINVAL { @@ -54,6 +47,7 @@ func BenchmarkCompareErrno(b *testing.B) { } func BenchmarkCompareError(b *testing.B) { + globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { if globalError == syserror.EINVAL { @@ -63,6 +57,7 @@ func BenchmarkCompareError(b *testing.B) { } func BenchmarkSwitchErrno(b *testing.B) { + globalError = syscall.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { @@ -77,6 +72,7 @@ func BenchmarkSwitchErrno(b *testing.B) { } func BenchmarkSwitchError(b *testing.B) { + globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index d82ed5205..4f551cd92 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { // Accept implements net.Conn.Accept. func (l *TCPListener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() + n, wq, err := l.ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { defer l.wq.EventUnregister(&waitEntry) for { - n, wq, err = l.ep.Accept() + n, wq, err = l.ep.Accept(nil) if err != tcpip.ErrWouldBlock { break @@ -541,7 +541,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, case <-notifyCh: } - err = ep.GetSockOpt(tcpip.ErrorOption{}) + err = ep.LastError() } if err != nil { ep.Close() diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 3c552988a..12b061def 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -61,8 +61,8 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { @@ -104,7 +104,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er err = ep.Connect(addr) if err == tcpip.ErrConnectStarted { <-ch - err = ep.GetSockOpt(tcpip.ErrorOption{}) + err = ep.LastError() } if err != nil { return nil, err diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 563bc78ea..c326fab54 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -14,6 +14,8 @@ go_library( go_test( name = "buffer_test", size = "small", - srcs = ["view_test.go"], + srcs = [ + "view_test.go", + ], library = ":buffer", ) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 9a3c5d6c3..8db70a700 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -65,6 +65,16 @@ func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) } +// IsEmpty returns whether v is of length zero. +func (v View) IsEmpty() bool { + return len(v) == 0 +} + +// Size returns the length of v. +func (v View) Size() int { + return len(v) +} + // VectorisedView is a vectorised version of View using non contiguous memory. // It supports all the convenience methods supported by View. // @@ -74,8 +84,8 @@ type VectorisedView struct { size int } -// NewVectorisedView creates a new vectorised view from an already-allocated slice -// of View and sets its size. +// NewVectorisedView creates a new vectorised view from an already-allocated +// slice of View and sets its size. func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } @@ -160,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) { } // Clone returns a clone of this VectorisedView. -// If the buffer argument is large enough to contain all the Views of this VectorisedView, -// the method will avoid allocations and use the buffer to store the Views of the clone. +// If the buffer argument is large enough to contain all the Views of this +// VectorisedView, the method will avoid allocations and use the buffer to +// store the Views of the clone. func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } @@ -199,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) { return newFirst, true } -// Size returns the size in bytes of the entire content stored in the vectorised view. +// Size returns the size in bytes of the entire content stored in the +// vectorised view. func (vv *VectorisedView) Size() int { return vv.size } @@ -212,6 +224,12 @@ func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } + return vv.ToOwnedView() +} + +// ToOwnedView returns a single view containing the content of the vectorised +// view that vv does not own. +func (vv *VectorisedView) ToOwnedView() View { u := make([]byte, 0, vv.size) for _, v := range vv.views { u = append(u, v...) diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index ed434807f..c984470e6 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -12,5 +12,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index ee264b726..d4d785cca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -117,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker { v = ip.HopLimit() } if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) + } + } +} + +// IPFullLength creates a checker for the full IP packet length. The +// expected size is checked against both the Total Length in the +// header and the number of bytes received. +func IPFullLength(packetLength uint16) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + var v uint16 + var l uint16 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TotalLength() + l = uint16(len(ip)) + case header.IPv6: + v = ip.PayloadLength() + header.IPv6FixedHeaderSize + l = uint16(len(ip)) + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) + } + if l != packetLength { + t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) + } + if v != packetLength { + t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) + } + } +} + +// IPv4HeaderLength creates a checker that checks the IPv4 Header length. +func IPv4HeaderLength(headerLength int) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + switch ip := h[0].(type) { + case header.IPv4: + if hl := ip.HeaderLength(); hl != uint8(headerLength) { + t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) + } + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) } } } // PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { +func PayloadLen(payloadLength int) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) + if l := len(h[0].Payload()); l != payloadLength { + t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) + } + } +} + +// IPv4Options returns a checker that checks the options in an IPv4 packet. +func IPv4Options(want []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + options := ip.Options() + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(options) == 0 { + return + } + if diff := cmp.Diff(want, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) } } } @@ -138,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) } } } @@ -153,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) } } } @@ -169,10 +234,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTClass { - t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) - } - if got := cm.TClass; got != want { - t.Fatalf("got cm.TClass = %d, want %d", got, want) + t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) + } else if got := cm.TClass; got != want { + t.Errorf("got cm.TClass = %d, want %d", got, want) } } } @@ -182,10 +246,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) + } else if got := cm.TOS; got != want { + t.Errorf("got cm.TOS = %d, want %d", got, want) } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) + } +} + +// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in +// ControlMessages. +func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPPacketInfo { + t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) + } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { + t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) } } } @@ -196,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker { t.Helper() if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) } } } @@ -222,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { t.Helper() if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) @@ -249,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -285,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) @@ -304,7 +380,7 @@ func SrcPort(port uint16) TransportChecker { t.Helper() if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got = %d, want = %d", p, port) } } } @@ -315,7 +391,7 @@ func DstPort(port uint16) TransportChecker { t.Helper() if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got = %d, want = %d", p, port) } } } @@ -327,7 +403,7 @@ func NoChecksum(noChecksum bool) TransportChecker { udp, ok := h.(header.UDP) if !ok { - return + t.Fatalf("UDP header not found in h: %T", h) } if b := udp.Checksum() == 0; b != noChecksum { @@ -336,50 +412,84 @@ func NoChecksum(noChecksum bool) TransportChecker { } } -// SeqNum creates a checker that checks the sequence number. -func SeqNum(seq uint32) TransportChecker { +// TCPSeqNum creates a checker that checks the sequence number. +func TCPSeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) } } } -// AckNum creates a checker that checks the ack number. -func AckNum(seq uint32) TransportChecker { +// TCPAckNum creates a checker that checks the ack number. +func TCPAckNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got = %d, want = %d", s, seq) } } } -// Window creates a checker that checks the tcp window. -func Window(window uint16) TransportChecker { +// TCPWindow creates a checker that checks the tcp window. +func TCPWindow(window uint16) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in hdr : %T", h) } if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) + t.Errorf("Bad window, got %d, want %d", w, window) + } + } +} + +// TCPWindowGreaterThanEq creates a checker that checks that the TCP window +// is greater than or equal to the provided value. +func TCPWindowGreaterThanEq(window uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + tcp, ok := h.(header.TCP) + if !ok { + t.Fatalf("TCP header not found in h: %T", h) + } + + if w := tcp.WindowSize(); w < window { + t.Errorf("Bad window, got %d, want > %d", w, window) + } + } +} + +// TCPWindowLessThanEq creates a checker that checks that the tcp window +// is less than or equal to the provided value. +func TCPWindowLessThanEq(window uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + tcp, ok := h.(header.TCP) + if !ok { + t.Fatalf("TCP header not found in h: %T", h) + } + + if w := tcp.WindowSize(); w > window { + t.Errorf("Bad window, got %d, want < %d", w, window) } } } @@ -391,7 +501,7 @@ func TCPFlags(flags uint8) TransportChecker { tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if f := tcp.Flags(); f != flags { @@ -408,7 +518,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if f := tcp.Flags(); (f & mask) != (flags & mask) { @@ -446,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) } foundMSS = true i += 4 @@ -456,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } v := int(opts[i+2]) if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) } foundWS = true i += 3 @@ -505,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { t.Errorf("SACKPermitted option not found. Options: %x", opts) @@ -543,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -563,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) } } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. +// TCPNoSACKBlockChecker creates a checker that verifies that the segment does +// not contain any SACK blocks in the TCP options. func TCPNoSACKBlockChecker() TransportChecker { return TCPSACKBlockChecker(nil) } @@ -633,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) } } } @@ -649,8 +759,8 @@ func Payload(want []byte) TransportChecker { } } -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and -// potentially additional ICMPv4 header fields. +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 +// and potentially additional ICMPv4 header fields. func ICMPv4(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -678,25 +788,91 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } // ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want byte) TransportChecker { +func ICMPv4Code(want header.ICMPv4Code) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. +func ICMPv4Ident(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Ident(); got != want { + t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. +func ICMPv4Seq(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Sequence(); got != want { + t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. +// This assumes that the payload exactly makes up the rest of the slice. +func ICMPv4Checksum() TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + heldChecksum := icmpv4.Checksum() + icmpv4.SetChecksum(0) + newChecksum := ^header.Checksum(icmpv4, 0) + icmpv4.SetChecksum(heldChecksum) + if heldChecksum != newChecksum { + t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) + } + } +} + +// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. +func ICMPv4Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + payload := icmpv4.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } } } @@ -736,25 +912,25 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } // ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want byte) TransportChecker { +func ICMPv6Code(want header.ICMPv6Code) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) } } } diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD new file mode 100644 index 000000000..114d43df3 --- /dev/null +++ b/pkg/tcpip/faketime/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "faketime", + srcs = ["faketime.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@com_github_dpjacques_clockwork//:go_default_library", + ], +) + +go_test( + name = "faketime_test", + size = "small", + srcs = [ + "faketime_test.go", + ], + deps = [ + "//pkg/tcpip/faketime", + ], +) diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go new file mode 100644 index 000000000..f7a4fbde1 --- /dev/null +++ b/pkg/tcpip/faketime/faketime.go @@ -0,0 +1,236 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package faketime provides a fake clock that implements tcpip.Clock interface. +package faketime + +import ( + "container/heap" + "sync" + "time" + + "github.com/dpjacques/clockwork" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// NullClock implements a clock that never advances. +type NullClock struct{} + +var _ tcpip.Clock = (*NullClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (*NullClock) NowNanoseconds() int64 { + return 0 +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (*NullClock) NowMonotonic() int64 { + return 0 +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { + return nil +} + +// ManualClock implements tcpip.Clock and only advances manually with Advance +// method. +type ManualClock struct { + clock clockwork.FakeClock + + // mu protects the fields below. + mu sync.RWMutex + + // times is min-heap of times. A heap is used for quick retrieval of the next + // upcoming time of scheduled work. + times *timeHeap + + // waitGroups stores one WaitGroup for all work scheduled to execute at the + // same time via AfterFunc. This allows parallel execution of all functions + // passed to AfterFunc scheduled for the same time. + waitGroups map[time.Time]*sync.WaitGroup +} + +// NewManualClock creates a new ManualClock instance. +func NewManualClock() *ManualClock { + return &ManualClock{ + clock: clockwork.NewFakeClock(), + times: &timeHeap{}, + waitGroups: make(map[time.Time]*sync.WaitGroup), + } +} + +var _ tcpip.Clock = (*ManualClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (mc *ManualClock) NowNanoseconds() int64 { + return mc.clock.Now().UnixNano() +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (mc *ManualClock) NowMonotonic() int64 { + return mc.NowNanoseconds() +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := mc.clock.Now().Add(d) + wg := mc.addWait(until) + return &manualTimer{ + clock: mc, + until: until, + timer: mc.clock.AfterFunc(d, func() { + defer wg.Done() + f() + }), + } +} + +// addWait adds an additional wait to the WaitGroup for parallel execution of +// all work scheduled for t. Returns a reference to the WaitGroup modified. +func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { + mc.mu.RLock() + wg, ok := mc.waitGroups[t] + mc.mu.RUnlock() + + if ok { + wg.Add(1) + return wg + } + + mc.mu.Lock() + heap.Push(mc.times, t) + mc.mu.Unlock() + + wg = &sync.WaitGroup{} + wg.Add(1) + + mc.mu.Lock() + mc.waitGroups[t] = wg + mc.mu.Unlock() + + return wg +} + +// removeWait removes a wait from the WaitGroup for parallel execution of all +// work scheduled for t. +func (mc *ManualClock) removeWait(t time.Time) { + mc.mu.RLock() + defer mc.mu.RUnlock() + + wg := mc.waitGroups[t] + wg.Done() +} + +// Advance executes all work that have been scheduled to execute within d from +// the current time. Blocks until all work has completed execution. +func (mc *ManualClock) Advance(d time.Duration) { + // Block until all the work is done + until := mc.clock.Now().Add(d) + for { + mc.mu.Lock() + if mc.times.Len() == 0 { + mc.mu.Unlock() + break + } + + t := heap.Pop(mc.times).(time.Time) + if t.After(until) { + // No work to do + heap.Push(mc.times, t) + mc.mu.Unlock() + break + } + mc.mu.Unlock() + + diff := t.Sub(mc.clock.Now()) + mc.clock.Advance(diff) + + mc.mu.RLock() + wg := mc.waitGroups[t] + mc.mu.RUnlock() + + wg.Wait() + + mc.mu.Lock() + delete(mc.waitGroups, t) + mc.mu.Unlock() + } + if now := mc.clock.Now(); until.After(now) { + mc.clock.Advance(until.Sub(now)) + } +} + +type manualTimer struct { + clock *ManualClock + timer clockwork.Timer + + mu sync.RWMutex + until time.Time +} + +var _ tcpip.Timer = (*manualTimer)(nil) + +// Reset implements tcpip.Timer.Reset. +func (t *manualTimer) Reset(d time.Duration) { + if !t.timer.Reset(d) { + return + } + + t.mu.Lock() + defer t.mu.Unlock() + + t.clock.removeWait(t.until) + t.until = t.clock.clock.Now().Add(d) + t.clock.addWait(t.until) +} + +// Stop implements tcpip.Timer.Stop. +func (t *manualTimer) Stop() bool { + if !t.timer.Stop() { + return false + } + + t.mu.RLock() + defer t.mu.RUnlock() + + t.clock.removeWait(t.until) + return true +} + +type timeHeap []time.Time + +var _ heap.Interface = (*timeHeap)(nil) + +func (h timeHeap) Len() int { + return len(h) +} + +func (h timeHeap) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +func (h timeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *timeHeap) Push(x interface{}) { + *h = append(*h, x.(time.Time)) +} + +func (h *timeHeap) Pop() interface{} { + last := (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + return last +} diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go new file mode 100644 index 000000000..c2704df2c --- /dev/null +++ b/pkg/tcpip/faketime/faketime_test.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package faketime_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" +) + +func TestManualClockAdvance(t *testing.T) { + const timeout = time.Millisecond + clock := faketime.NewManualClock() + start := clock.NowMonotonic() + clock.Advance(timeout) + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + t.Errorf("got = %d, want = %d", got, want) + } +} + +func TestManualClockAfterFunc(t *testing.T) { + const ( + timeout1 = time.Millisecond // timeout for counter1 + timeout2 = 2 * time.Millisecond // timeout for counter2 + ) + tests := []struct { + name string + advance time.Duration + wantCounter1 int + wantCounter2 int + }{ + { + name: "before timeout1", + advance: timeout1 - 1, + wantCounter1: 0, + wantCounter2: 0, + }, + { + name: "timeout1", + advance: timeout1, + wantCounter1: 1, + wantCounter2: 0, + }, + { + name: "timeout2", + advance: timeout2, + wantCounter1: 1, + wantCounter2: 1, + }, + { + name: "after timeout2", + advance: timeout2 + 1, + wantCounter1: 1, + wantCounter2: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + counter1 := 0 + counter2 := 0 + clock.AfterFunc(timeout1, func() { + counter1++ + }) + clock.AfterFunc(timeout2, func() { + counter2++ + }) + start := clock.NowMonotonic() + clock.Advance(test.advance) + if got, want := counter1, test.wantCounter1; got != want { + t.Errorf("got counter1 = %d, want = %d", got, want) + } + if got, want := counter2, test.wantCounter2; got != want { + t.Errorf("got counter2 = %d, want = %d", got, want) + } + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + t.Errorf("got elapsed = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0cde694dc..d87797617 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -48,7 +48,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -64,6 +64,6 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. +type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index b1e92d2d7..eaface8cb 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -53,6 +53,10 @@ const ( // (all bits set to 0). unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + // EthernetBroadcastAddress is an ethernet address that addresses every node + // on a local link. + EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") + // unicastMulticastFlagMask is the mask of the least significant bit in // the first octet (in network byte order) of an ethernet address that // determines whether the ethernet address is a unicast or multicast. If diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 7908c5744..504408878 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -31,6 +31,27 @@ const ( // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 + // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an + // errant packet's transport layer that an ICMP error type packet should + // attempt to send as per RFC 792 (see each type) and RFC 1122 + // section 3.2.2 which states: + // Every ICMP error message includes the Internet header and at + // least the first 8 data octets of the datagram that triggered + // the error; more than 8 octets MAY be sent; this header and data + // MUST be unchanged from the received datagram. + // + // RFC 792 shows: + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Code | Checksum | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | unused | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Internet Header + 64 bits of Original Data Datagram | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ICMPv4MinimumErrorPayloadSize = 8 + // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 @@ -39,21 +60,28 @@ const ( icmpv4ChecksumOffset = 2 // icmpv4MTUOffset is the offset of the MTU field - // in a ICMPv4FragmentationNeeded message. + // in an ICMPv4FragmentationNeeded message. icmpv4MTUOffset = 6 // icmpv4IdentOffset is the offset of the ident field - // in a ICMPv4EchoRequest/Reply message. + // in an ICMPv4EchoRequest/Reply message. icmpv4IdentOffset = 4 + // icmpv4PointerOffset is the offset of the pointer field + // in an ICMPv4ParamProblem message. + icmpv4PointerOffset = 4 + // icmpv4SequenceOffset is the offset of the sequence field - // in a ICMPv4EchoRequest/Reply message. + // in an ICMPv4EchoRequest/Reply message. icmpv4SequenceOffset = 6 ) // ICMPv4Type is the ICMP type field described in RFC 792. type ICMPv4Type byte +// ICMPv4Code is the ICMP code field described in RFC 792. +type ICMPv4Code byte + // Typical values of ICMPv4Type defined in RFC 792. const ( ICMPv4EchoReply ICMPv4Type = 0 @@ -69,13 +97,23 @@ const ( ICMPv4InfoReply ICMPv4Type = 16 ) -// Values for ICMP code as defined in RFC 792. +// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. const ( - ICMPv4TTLExceeded = 0 - ICMPv4PortUnreachable = 3 - ICMPv4FragmentationNeeded = 4 + ICMPv4TTLExceeded ICMPv4Code = 0 ) +// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. +const ( + ICMPv4NetUnreachable ICMPv4Code = 0 + ICMPv4HostUnreachable ICMPv4Code = 1 + ICMPv4ProtoUnreachable ICMPv4Code = 2 + ICMPv4PortUnreachable ICMPv4Code = 3 + ICMPv4FragmentationNeeded ICMPv4Code = 4 +) + +// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed. +const ICMPv4UnusedCode ICMPv4Code = 0 + // Type is the ICMP type field. func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } @@ -83,10 +121,10 @@ func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) } // Code is the ICMP code field. Its meaning depends on the value of Type. -func (b ICMPv4) Code() byte { return b[1] } +func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. -func (b ICMPv4) SetCode(c byte) { b[1] = c } +func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index c7ee2de57..6be31beeb 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -54,9 +54,17 @@ const ( // address. ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize - // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. + // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet. ICMPv6EchoMinimumSize = 8 + // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header, + // as per RFC 4443, Apendix A, item 4 and the errata. + // ... all ICMP error messages shall have exactly + // 32 bits of type-specific data, so that receivers can reliably find + // the embedded invoking packet even when they don't recognize the + // ICMP message Type. + ICMPv6ErrorHeaderSize = 8 + // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize @@ -69,6 +77,10 @@ const ( // in an ICMPv6 message. icmpv6ChecksumOffset = 2 + // icmpv6PointerOffset is the offset of the pointer + // in an ICMPv6 Parameter problem message. + icmpv6PointerOffset = 4 + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 // PacketTooBig message. icmpv6MTUOffset = 4 @@ -89,10 +101,10 @@ const ( NDPHopLimit = 255 ) -// ICMPv6Type is the ICMP type field described in RFC 4443 and friends. +// ICMPv6Type is the ICMP type field described in RFC 4443. type ICMPv6Type byte -// Typical values of ICMPv6Type defined in RFC 4443. +// Values for use in the Type field of ICMPv6 packet from RFC 4433. const ( ICMPv6DstUnreachable ICMPv6Type = 1 ICMPv6PacketTooBig ICMPv6Type = 2 @@ -110,11 +122,54 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP code as defined in RFC 4443. +// IsErrorType returns true if the receiver is an ICMP error type. +func (typ ICMPv6Type) IsErrorType() bool { + // Per RFC 4443 section 2.1: + // ICMPv6 messages are grouped into two classes: error messages and + // informational messages. Error messages are identified as such by a + // zero in the high-order bit of their message Type field values. Thus, + // error messages have message types from 0 to 127; informational + // messages have message types from 128 to 255. + return typ&0x80 == 0 +} + +// ICMPv6Code is the ICMP Code field described in RFC 4443. +type ICMPv6Code byte + +// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443 +// section 3.1. +const ( + ICMPv6NetworkUnreachable ICMPv6Code = 0 + ICMPv6Prohibited ICMPv6Code = 1 + ICMPv6BeyondScope ICMPv6Code = 2 + ICMPv6AddressUnreachable ICMPv6Code = 3 + ICMPv6PortUnreachable ICMPv6Code = 4 + ICMPv6Policy ICMPv6Code = 5 + ICMPv6RejectRoute ICMPv6Code = 6 +) + +// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3. +const ( + ICMPv6HopLimitExceeded ICMPv6Code = 0 + ICMPv6ReassemblyTimeout ICMPv6Code = 1 +) + +// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. const ( - ICMPv6PortUnreachable = 4 + // ICMPv6ErroneousHeader indicates an erroneous header field was encountered. + ICMPv6ErroneousHeader ICMPv6Code = 0 + + // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered. + ICMPv6UnknownHeader ICMPv6Code = 1 + + // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered. + ICMPv6UnknownOption ICMPv6Code = 2 ) +// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use +// the code field. (Types not mentioned above.) +const ICMPv6UnusedCode ICMPv6Code = 0 + // Type is the ICMP type field. func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } @@ -122,10 +177,20 @@ func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) } // Code is the ICMP code field. Its meaning depends on the value of Type. -func (b ICMPv6) Code() byte { return b[1] } +func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. -func (b ICMPv6) SetCode(c byte) { b[1] = c } +func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } + +// TypeSpecific returns the type specific data field. +func (b ICMPv6) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +// SetTypeSpecific sets the type specific data field. +func (b ICMPv6) SetTypeSpecific(val uint32) { + binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) +} // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 62ac932bb..4c6e4be64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,10 +16,29 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" ) +// RFC 971 defines the fields of the IPv4 header on page 11 using the following +// diagram: ("Figure 4") +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Identification |Flags| Fragment Offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Time to Live | Protocol | Header Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Source Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Destination Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options | Padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// const ( versIHL = 0 tos = 1 @@ -33,6 +52,7 @@ const ( checksum = 10 srcAddr = 12 dstAddr = 16 + options = 20 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -76,11 +96,13 @@ type IPv4Fields struct { // IPv4 represents an ipv4 header stored in a byte array. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. -// Always call IsValid() to validate an instance of IPv4 before using other methods. +// Always call IsValid() to validate an instance of IPv4 before using other +// methods. type IPv4 []byte const ( - // IPv4MinimumSize is the minimum size of a valid IPv4 packet. + // IPv4MinimumSize is the minimum size of a valid IPv4 packet; + // i.e. a packet header with no options. IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given @@ -88,6 +110,16 @@ const ( // units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. + // + // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 + // header size). But RFC 791 section 3.2 discusses the design of the IPv4 + // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of + // 65,536 octets. Note that this is consistent with the the datagram total + // length field (of course, the header is counted in the total length and not + // in the fragments)." + IPv4MaximumPayloadSize = 65536 + // MinIPFragmentPayloadSize is the minimum number of payload bytes that // the first fragment must carry when an IPv4 packet is fragmented. MinIPFragmentPayloadSize = 8 @@ -101,6 +133,11 @@ const ( // IPv4Version is the version of the ipv4 protocol. IPv4Version = 4 + // IPv4AllSystems is the all systems IPv4 multicast address as per + // IANA's IPv4 Multicast Address Space Registry. See + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml. + IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01" + // IPv4Broadcast is the broadcast address of the IPv4 procotol. IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" @@ -135,13 +172,44 @@ func IPVersion(b []byte) int { if len(b) < versIHL+1 { return -1 } - return int(b[versIHL] >> 4) + return int(b[versIHL] >> ipVersionShift) } +// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits +// of the first byte, and is counted in multiples of 4 bytes. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// (...) +// Version: 4 bits +// The Version field indicates the format of the internet header. This +// document describes version 4. +// +// IHL: 4 bits +// Internet Header Length is the length of the internet header in 32 +// bit words, and thus points to the beginning of the data. Note that +// the minimum value for a correct header is 5. +// +const ( + ipVersionShift = 4 + ipIHLMask = 0x0f + IPv4IHLStride = 4 +) + // HeaderLength returns the value of the "header length" field of the ipv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { - return (b[versIHL] & 0xf) * 4 + return (b[versIHL] & ipIHLMask) * IPv4IHLStride +} + +// SetHeaderLength sets the value of the "Internet Header Length" field. +func (b IPv4) SetHeaderLength(hdrLen uint8) { + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize)) + } + b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } // ID returns the value of the identifier field of the ipv4 header. @@ -195,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// Options returns a a buffer holding the options. +func (b IPv4) Options() []byte { + hdrLen := b.HeaderLength() + return b[options:hdrLen:hdrLen] +} + // TransportProtocol implements Network.TransportProtocol. func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber { return tcpip.TransportProtocolNumber(b.Protocol()) @@ -220,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } +// SetTTL sets the "Time to Live" field of the IPv4 header. +func (b IPv4) SetTTL(v byte) { + b[ttl] = v +} + // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) @@ -260,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 { // Encode encodes all the fields of the ipv4 header. func (b IPv4) Encode(i *IPv4Fields) { - b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) + b.SetHeaderLength(i.IHL) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) @@ -310,3 +389,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool { } return (addr[0] & 0xf0) == 0xe0 } + +// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback +// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3. +func IsV4LoopbackAddress(addr tcpip.Address) bool { + if len(addr) != IPv4AddressSize { + return false + } + return addr[0] == 0x7f +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 4f367fe4c..ef454b313 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -34,6 +34,9 @@ const ( hopLimit = 7 v6SrcAddr = 8 v6DstAddr = v6SrcAddr + IPv6AddressSize + + // IPv6FixedHeaderSize is the size of the fixed header. + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -69,11 +72,15 @@ type IPv6 []byte const ( // IPv6MinimumSize is the minimum size of a valid IPv6 packet. - IPv6MinimumSize = 40 + IPv6MinimumSize = IPv6FixedHeaderSize // IPv6AddressSize is the size, in bytes, of an IPv6 address. IPv6AddressSize = 16 + // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per + // RFC 8200 Section 4.5. + IPv6MaximumPayloadSize = 65535 + // IPv6ProtocolNumber is IPv6's network protocol number. IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd @@ -98,6 +105,9 @@ const ( // section 5. IPv6MinimumMTU = 1280 + // IPv6Loopback is the IPv6 Loopback address. + IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 3499d8399..583c2c5d3 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { // obtained before modification is no longer used. type IPv6OptionsExtHdrOptionsIterator struct { reader bytes.Reader + + // optionOffset is the number of bytes from the first byte of the + // options field to the beginning of the current option. + optionOffset uint32 + + // nextOptionOffset is the offset of the next option. + nextOptionOffset uint32 +} + +// OptionOffset returns the number of bytes parsed while processing the +// option field of the current Extension Header. +func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { + return i.optionOffset } // IPv6OptionUnknownAction is the action that must be taken if the processing @@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} // the options data, or an error occured. func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { for { + i.optionOffset = i.nextOptionOffset temp, err := i.reader.ReadByte() if err != nil { // If we can't read the first byte of a new option, then we know the @@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // know the option does not have Length and Data fields. End processing of // the Pad1 option and continue processing the buffer as a new option. if id == ipv6Pad1ExtHdrOptionIdentifier { + i.nextOptionOffset = i.optionOffset + 1 continue } @@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) } - // Special-case the variable length padding option to avoid a copy. - if id == ipv6PadNExtHdrOptionIdentifier { - // Do we have enough bytes in the reader for the PadN option? - if n := i.reader.Len(); n < int(length) { - // Reset the reader to effectively consume the remaining buffer. - i.reader.Reset(nil) - - // We return the same error as if we failed to read a non-padding option - // so consumers of this iterator don't need to differentiate between - // padding and non-padding options. - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) - } + // Do we have enough bytes in the reader for the next option? + if n := i.reader.Len(); n < int(length) { + // Reset the reader to effectively consume the remaining buffer. + i.reader.Reset(nil) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ + switch id { + case ipv6PadNExtHdrOptionIdentifier: + // Special-case the variable length padding option to avoid a copy. if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil { panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } - - // End processing of the PadN option and continue processing the buffer as - // a new option. continue - } - - bytes := make([]byte, length) - if n, err := io.ReadFull(&i.reader, bytes); err != nil { - // io.ReadFull may return io.EOF if i.reader has been exhausted. We use - // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the - // Length field found in the option. - if err == io.EOF { - err = io.ErrUnexpectedEOF + default: + bytes := make([]byte, length) + if n, err := io.ReadFull(&i.reader, bytes); err != nil { + // io.ReadFull may return io.EOF if i.reader has been exhausted. We use + // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the + // Length field found in the option. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) } - - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } - - return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } } @@ -382,6 +396,29 @@ type IPv6PayloadIterator struct { // Indicates to the iterator that it should return the remaining payload as a // raw payload on the next call to Next. forceRaw bool + + // headerOffset is the offset of the beginning of the current extension + // header starting from the beginning of the fixed header. + headerOffset uint32 + + // parseOffset is the byte offset into the current extension header of the + // field we are currently examining. It can be added to the header offset + // if the absolute offset within the packet is required. + parseOffset uint32 + + // nextOffset is the offset of the next header. + nextOffset uint32 +} + +// HeaderOffset returns the offset to the start of the extension +// header most recently processed. +func (i IPv6PayloadIterator) HeaderOffset() uint32 { + return i.headerOffset +} + +// ParseOffset returns the number of bytes successfully parsed. +func (i IPv6PayloadIterator) ParseOffset() uint32 { + return i.headerOffset + i.parseOffset } // MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing @@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa nextHdrIdentifier: nextHdrIdentifier, payload: payload.Clone(nil), // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. - reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + nextOffset: IPv6FixedHeaderSize, } } @@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Next is unable to return anything because the iterator has reached the end of // the payload, or an error occured. func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + i.headerOffset = i.nextOffset + i.parseOffset = 0 // We could be forced to return i as a raw header when the previous header was // a fragment extension header as the data following the fragment extension // header may not be complete. @@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { return IPv6RoutingExtHdr(bytes), false, nil case IPv6FragmentExtHdrIdentifier: var data [6]byte - // We ignore the returned bytes becauase we know the fragment extension + // We ignore the returned bytes because we know the fragment extension // header specific data will fit in data. nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) if err != nil { @@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP if err != nil { return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) } + i.parseOffset++ var length uint8 length, err = i.reader.ReadByte() i.payload.TrimFront(1) + if err != nil { if fragmentHdr { return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) @@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP length = 0 } + // Make parseOffset point to the first byte of the Extension Header + // specific data. + i.parseOffset++ + + // length is in 8 byte chunks but doesn't include the first one. + // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement + // in section 4.8 for new extension headers at the top of page 24. + // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet + // units, not including the first 8 octets. + i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded if bytes == nil { bytes = make([]byte, bytesLen) diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go index b5540bf66..17a49d4fa 100644 --- a/pkg/tcpip/header/ipversion_test.go +++ b/pkg/tcpip/header/ipversion_test.go @@ -22,7 +22,7 @@ import ( func TestIPv4(t *testing.T) { b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) + b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize}) const want = header.IPv4Version if v := header.IPVersion(b); v != want { diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD new file mode 100644 index 000000000..2adee9288 --- /dev/null +++ b/pkg/tcpip/header/parse/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "parse", + srcs = ["parse.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go new file mode 100644 index 000000000..5ca75c834 --- /dev/null +++ b/pkg/tcpip/header/parse/parse.go @@ -0,0 +1,168 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package parse provides utilities to parse packets. +package parse + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// ARP populates pkt's network header with an ARP header found in +// pkt.Data. +// +// Returns true if the header was successfully parsed. +func ARP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.NetworkHeader().Consume(header.ARPSize) + if ok { + pkt.NetworkProtocolNumber = header.ARPProtocolNumber + } + return ok +} + +// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network +// header with the IPv4 header. +// +// Returns true if the header was successfully parsed. +func IPv4(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return false + } + ipHdr := header.IPv4(hdr) + + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return false + } + ipHdr = header.IPv4(hdr) + + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return true +} + +// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network +// header with the IPv6 header. +func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, 0, 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + var nextHdr tcpip.TransportProtocolNumber + var extensionsSize int + +traverseExtensions: + for { + extHdr, done, err := it.Next() + if err != nil { + break + } + + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + if fragID == 0 && fragOffset == 0 && !fragMore { + fragID = extHdr.ID() + fragOffset = extHdr.FragmentOffset() + fragMore = extHdr.More() + } + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + return nextHdr, fragID, fragOffset, fragMore, true +} + +// UDP parses a UDP packet found in pkt.Data and populates pkt's transport +// header with the UDP header. +// +// Returns true if the header was successfully parsed. +func UDP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + return ok +} + +// TCP parses a TCP packet found in pkt.Data and populates pkt's transport +// header with the TCP header. +// +// Returns true if the header was successfully parsed. +func TCP(pkt *stack.PacketBuffer) bool { + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset + } + + _, ok = pkt.TransportHeader().Consume(hdrLen) + pkt.TransportProtocolNumber = header.TCPProtocolNumber + return ok +} diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 9339d637f..98bdd29db 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "math" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -55,6 +56,10 @@ const ( // UDPMinimumSize is the minimum size of a valid UDP packet. UDPMinimumSize = 8 + // UDPMaximumSize is the maximum size of a valid UDP packet. The length field + // in the UDP header is 16 bits as per RFC 768. + UDPMaximumSize = math.MaxUint16 + // UDPProtocolNumber is UDP's transport protocol number. UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index b8b93e78e..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 20b183da0..c95aef63c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -273,7 +274,9 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: &stack.PacketBuffer{Data: vv}, + Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }), Proto: 0, GSO: nil, } @@ -296,3 +299,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { e.q.RemoveNotify(handle) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index aa6db9aea..10072eac1 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/binary", + "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -36,5 +37,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f34082e1a..975309fc8 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,6 +45,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -385,32 +386,40 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen @@ -430,49 +439,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + builder.Add(vnetHdrBuf) } - if pkt.Data.Size() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Header.View()) - } - if pkt.Header.UsedLength() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + for _, v := range pkt.Views() { + builder.Add(v) } - - return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) + return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { - var ethHdrBuf []byte - iovLen := 0 if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - } - - // Preserve the src address if it's set in the route. - if pkt.EgressRoute.LocalLinkAddress != "" { - ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } - vnetHdr := virtioNetHdr{} var vnetHdrBuf []byte if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen @@ -491,45 +479,18 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc } } vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - iovLen++ } - iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + var builder iovec.Builder + builder.Add(vnetHdrBuf) + for _, v := range pkt.Views() { + builder.Add(v) + } + iovecs := builder.Build() + var mmsgHdr rawfile.MMsgHdr mmsgHdr.Msg.Iov = &iovecs[0] - iovecIdx := 0 - if vnetHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = &vnetHdrBuf[0] - v.Len = uint64(len(vnetHdrBuf)) - iovecIdx++ - } - if ethHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = ðHdrBuf[0] - v.Len = uint64(len(ethHdrBuf)) - iovecIdx++ - } - pktSize := uint64(0) - // Encode L3 Header - v := &iovecs[iovecIdx] - hdr := &pkt.Header - hdrView := hdr.View() - v.Base = &hdrView[0] - v.Len = uint64(len(hdrView)) - pktSize += v.Len - iovecIdx++ - - // Now encode the Transport Payload. - pktViews := pkt.Data.Views() - for i := range pktViews { - vec := &iovecs[iovecIdx] - iovecIdx++ - vec.Base = &pktViews[i][0] - vec.Len = uint64(len(pktViews[i])) - pktSize += vec.Len - } - mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + mmsgHdr.Msg.Iovlen = uint64(len(iovecs)) mmsgHdrs = append(mmsgHdrs, mmsgHdr) } @@ -626,6 +587,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index eaee7e5d7..709f829c8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,9 +44,36 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents *stack.PacketBuffer + Raddr tcpip.LinkAddress + Proto tcpip.NetworkProtocolNumber + Contents *stack.PacketBuffer +} + +type packetContents struct { + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + Data buffer.View +} + +func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { + t.Helper() + if diff := cmp.Diff( + want, got, + cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { + if pk == nil { + return nil + } + return &packetContents{ + LinkHeader: pk.LinkHeader().View(), + NetworkHeader: pk.NetworkHeader().View(), + TransportHeader: pk.TransportHeader().View(), + Data: pk.Data.ToView(), + } + }), + ); diff != "" { + t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) + } } type context struct { @@ -107,6 +135,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin c.ch <- packetInfo{remote, protocol, pkt} } +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNoEthernetProperties(t *testing.T) { c := newContext(t, &Options{MTU: mtu}) defer c.cleanup() @@ -155,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u RemoteLinkAddress: raddr, } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) + // Build payload. + payload := buffer.NewView(plen) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand.Read(payload): %s", err) } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) + // Build packet buffer. + const netHdrLen = 100 + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, + Data: payload.ToVectorisedView(), + }) + pkt.Hash = hash + + // Build header. + b := pkt.NetworkHeader().Push(netHdrLen) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read(b): %s", err) } - want := append(hdr.View(), payload...) + + // Write. + want := append(append(buffer.View(nil), b...), payload...) var gso *stack.GSO if gsoMaxSize != 0 { gso = &stack.GSO{ @@ -179,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - Hash: hash, - }); err != nil { + if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -292,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) { LocalLinkAddress: baddr, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.VectorisedView{}, + }) + if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -327,24 +365,25 @@ func TestDeliverPacket(t *testing.T) { defer c.cleanup() // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) + all := make([]byte, plen) + if _, err := rand.Read(all); err != nil { + t.Fatalf("rand.Read(all): %s", err) } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) + // Make it look like an IPv4 packet. + all[0] = 0x40 + + wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.NewViewFromBytes(all).ToVectorisedView(), + }) + if eth { + hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, Type: proto, }) - all = append(hdr, b...) + all = append(hdr, all...) } // Write packet via the file descriptor. @@ -356,24 +395,15 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: &stack.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, + Raddr: raddr, + Proto: proto, + Contents: wantPkt, } if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + want.Proto = header.IPv4ProtocolNumber + want.Raddr = "" } + checkPacketInfoEqual(t, pi, want) case <-time.After(10 * time.Second): t.Fatalf("Timed out waiting for packet") } @@ -500,3 +530,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) { } } + +// fakeNetworkDispatcher delivers packets to pkts. +type fakeNetworkDispatcher struct { + pkts []*stack.PacketBuffer +} + +func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + d.pkts = append(d.pkts, pkt) +} + +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestDispatchPacketFormat(t *testing.T) { + for _, test := range []struct { + name string + newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) + }{ + { + name: "readVDispatcher", + newDispatcher: newReadVDispatcher, + }, + { + name: "recvMMsgDispatcher", + newDispatcher: newRecvMMsgDispatcher, + }, + } { + t.Run(test.name, func(t *testing.T) { + // Create a socket pair to send/recv. + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Fatal(err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + data := []byte{ + // Ethernet header. + 1, 2, 3, 4, 5, 60, + 1, 2, 3, 4, 5, 61, + 8, 0, + // Mock network header. + 40, 41, 42, 43, + } + err = syscall.Sendmsg(fds[1], data, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + + // Create and run dispatcher once. + sink := &fakeNetworkDispatcher{} + d, err := test.newDispatcher(fds[0], &endpoint{ + hdrSize: header.EthernetMinimumSize, + dispatcher: sink, + }) + if err != nil { + t.Fatal(err) + } + if ok, err := d.dispatch(); !ok || err != nil { + t.Fatalf("d.dispatch() = %v, %v", ok, err) + } + + // Verify packet. + if got, want := len(sink.pkts), 1; got != want { + t.Fatalf("len(sink.pkts) = %d, want %d", got, want) + } + pkt := sink.pkts[0] + if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { + t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) + } + if got, want := pkt.Data.Size(), 4; got != want { + t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 2dfd29aa9..c475dda20 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -18,6 +18,7 @@ package fdbased import ( "encoding/binary" + "fmt" "syscall" "golang.org/x/sys/unix" @@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(pkt) + eth := header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } } - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ - Data: buffer.View(pkt).ToVectorisedView(), - LinkHeader: buffer.View(eth), + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(pkt).ToVectorisedView(), }) + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index f04738cfb..8c3ca86d6 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.allocateViews(BufConfig) n, err := rawfile.BlockingReadv(d.fd, d.iovecs) - if err != nil { + if n == 0 || err != nil { return false, err } if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { // isn't used and it isn't in a view. n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(n, BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(n, BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. @@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(k, int(n), BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(k, int(n), BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 568c6874f..38aa694e4 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -77,16 +77,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) return nil } @@ -98,18 +98,25 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) if !ok { // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } linkHeader := header.Ethernet(hdr) - vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ - Data: vv, - LinkHeader: buffer.View(linkHeader), - }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 82b441b79..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c69d6b7e9..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 0744f66d6..3e4afcdad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -46,14 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) { func TestInjectableEndpointDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -67,13 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) { func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewView(0).ToVectorisedView(), }) + pkt.TransportHeader().Push(1)[0] = 0xFA + packetRoute := stack.Route{RemoteAddress: dstIP} + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index bdd5276ad..2cdb23475 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 2998f9c4f..d40de54df 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -60,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco } } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.mu.Lock() @@ -129,3 +140,13 @@ func (e *Endpoint) GSOMaxSize() uint32 { } return 0 } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index c1a219f02..c1f9d308c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd d.count++ } +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNestedLinkEndpoint(t *testing.T) { const emptyAddress = tcpip.LinkAddress("") @@ -83,7 +87,7 @@ func TestNestedLinkEndpoint(t *testing.T) { t.Error("After attach, nestedEP.IsAttached() = false, want = true") } - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 1 { t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) } @@ -97,7 +101,7 @@ func TestNestedLinkEndpoint(t *testing.T) { } disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 0 { t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) } diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 054c213bc..1d0079bd6 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b5dfb7850..fc1e34fc7 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher @@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } @@ -207,3 +215,13 @@ func (e *endpoint) Wait() { e.wg.Wait() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..6c410c5a6 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,14 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + srcs = [ + "errors_test.go", + ], + library = "rawfile", + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 99313ee25..5db4bf12b 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index a0a873c84..604868fd8 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error // *tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. +// tcpip.ErrInvalidEndpointState (EINVAL). func TranslateErrno(e syscall.Errno) *tcpip.Error { - if err := translations[e]; err != nil { - return err + if e > 0 && e < syscall.Errno(len(translations)) { + if err := translations[e]; err != nil { + return err + } } return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go new file mode 100644 index 000000000..e4cdc66bd --- /dev/null +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestTranslateErrno(t *testing.T) { + for _, test := range []struct { + errno syscall.Errno + translated *tcpip.Error + }{ + { + errno: syscall.Errno(0), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(maxErrno), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(514), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.EEXIST, + translated: tcpip.ErrDuplicateAddress, + }, + } { + got := TranslateErrno(test.errno) + if got != test.translated { + t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + } + } +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 69de6eb3e..f4c32c2da 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -66,38 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { return nil } -// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a -// single syscall. It fails if partial data is written. -func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { - // If there is no second and third buffer, issue a regular write. - if len(b2) == 0 && len(b3) == 0 { - return NonBlockingWrite(fd, b1) - } - - // Build the iovec that represents them and issue a writev syscall. - iovec := [3]syscall.Iovec{ - { - Base: &b1[0], - Len: uint64(len(b1)), - }, - { - Base: &b2[0], - Len: uint64(len(b2)), - }, - } - iovecLen := uintptr(2) - - if len(b3) > 0 { - iovecLen++ - iovec[2].Base = &b3[0] - iovec[2].Len = uint64(len(b3)) - } - +// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. +// It fails if partial data is written. +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { + iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { return TranslateErrno(e) } - return nil } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0374a2441..7fb8a6c49 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -183,27 +183,33 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) - v := pkt.Data.ToView() + views := pkt.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(pkt.Header.View(), v) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -220,10 +226,10 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - v := vv.ToView() + views := vv.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -269,16 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { rxb[i].Size = e.bufferSize } - if n < header.EthernetMinimumSize { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { continue } + eth := header.Ethernet(hdr) // Send packet up the stack. - eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ - Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), - LinkHeader: buffer.View(eth), - }) + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) } // Clean state. @@ -287,3 +295,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 28a2e88ba..22d5c97f1 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) @@ -262,21 +266,23 @@ func TestSimpleSend(t *testing.T) { for iters := 1000; iters > 0; iters-- { func() { + hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) + // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) + hdrBuf := buffer.NewView(hdrLen) randomFill(hdrBuf) - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) + data := buffer.NewView(dataLen) + randomFill(data) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), + Data: data.ToVectorisedView(), + }) + copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -313,7 +319,7 @@ func TestSimpleSend(t *testing.T) { // Compare contents skipping the ethernet header added by the // endpoint. - merged := append(hdrBuf, buf...) + merged := append(hdrBuf, data...) if uint32(len(contents)) < pi.Size { t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) } @@ -340,14 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) { LocalLinkAddress: newLocalLinkAddress, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + ReserveHeaderBytes: header.EthernetMinimumSize, + }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -399,12 +405,12 @@ func TestFillTxQueue(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -418,11 +424,11 @@ func TestFillTxQueue(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -446,11 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -469,11 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -487,11 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -513,11 +519,11 @@ func TestFillTxMemory(t *testing.T) { // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -532,11 +538,11 @@ func TestFillTxMemory(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), }) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -560,11 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -575,23 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write a two-buffer packet. It must fail. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buffer.NewView(bufferSize).ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } // Attempt to write the one-buffer packet again. It must succeed. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..44f421c2d 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -76,9 +77,9 @@ func (t *tx) cleanup() { syscall.Munmap(t.data) } -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { +// transmit sends a packet made of bufs. Returns a boolean that specifies +// whether the packet was successfully transmitted. +func (t *tx) transmit(bufs ...buffer.View) bool { // Pull completions from the tx queue and add their buffers back to the // pool so that we can reuse them. for { @@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool { } bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) + total := uint32(0) + for _, data := range bufs { + total += uint32(len(data)) + } bufCount := (total + bSize - 1) / bSize // Allocate enough buffers to hold all the data. @@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool { // Copy data into allocated buffers. nBuf := buf var dBuf []byte - for _, data := range [][]byte{a, b} { + for _, data := range bufs { for len(data) > 0 { if len(dBuf) == 0 { dBuf = t.data[nBuf.Offset:][:nBuf.Size] diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 7cbc305e7..4aac12a8c 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index d9cd4e83a..560477926 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -123,13 +124,18 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) +} + func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Header.UsedLength() + pkt.Data.Size() + totalLength := pkt.Size() length := totalLength if max := int(e.maxPCAPLen); length > max { length = max @@ -150,12 +156,11 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw length -= n } } - write(pkt.Header.View()) - for _, view := range pkt.Data.Views() { + for _, v := range pkt.Views() { if length == 0 { break } - write(view) + write(v) } } } @@ -180,9 +185,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ + e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) return e.Endpoint.WriteRawPacket(vv) } @@ -191,53 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var transProto uint8 src := tcpip.Address("unknown") dst := tcpip.Address("unknown") - id := 0 - size := uint16(0) + var size uint16 + var id uint32 var fragmentOffset uint16 var moreFragments bool - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - + // Clone the packet buffer to not modify the original. + // + // We don't clone the original packet buffer so that the new packet buffer + // does not have any of its headers set. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return } - ipv4 := header.IPv4(hdr) + + ipv4 := header.IPv4(pkt.NetworkHeader().View()) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) - id = int(ipv4.ID()) + id = uint32(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) + proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return } - ipv6 := header.IPv6(hdr) + + ipv6 := header.IPv6(pkt.NetworkHeader().View()) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() - transProto = ipv6.NextHeader() + transProto = uint8(proto) size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + id = fragID + moreFragments = fragMore + fragmentOffset = fragOffset case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { + if parse.ARP(pkt) { return } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + + arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( "%s arp %s (%s) -> %s (%s) valid:%t", prefix, @@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { + if ok := parse.UDP(pkt); !ok { break } - udp := header.UDP(hdr) + + udp := header.UDP(pkt.TransportHeader().View()) if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { + if ok := parse.TCP(pkt); !ok { break } - tcp := header.TCP(hdr) + + tcp := header.TCP(pkt.TransportHeader().View()) if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > vv.Size() && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) + if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e0db6cf54..0243424f6 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,17 +1,32 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "tun_endpoint_refs", + out = "tun_endpoint_refs.go", + package = "tun", + prefix = "tunEndpoint", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "tunEndpoint", + }, +) + go_library( name = "tun", srcs = [ "device.go", "protocol.go", + "tun_endpoint_refs.go", "tun_unsafe.go", ], visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", + "//pkg/context", + "//pkg/log", "//pkg/refs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6bc9033d0..b6ddbe81e 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -18,7 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -64,14 +64,14 @@ func (d *Device) beforeSave() { } // Release implements fs.FileOperations.Release. -func (d *Device) Release() { +func (d *Device) Release(ctx context.Context) { d.mu.Lock() defer d.mu.Unlock() // Decrease refcount if there is an endpoint associated with this file. if d.endpoint != nil { d.endpoint.RemoveNotify(d.notifyHandle) - d.endpoint.DecRef() + d.endpoint.DecRef(ctx) d.endpoint = nil } } @@ -134,11 +134,13 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) + // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, nicID: id, name: name, + isTap: prefix == "tap", } endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { @@ -213,12 +215,11 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := &stack.PacketBuffer{ - Data: buffer.View(data).ToVectorisedView(), - } - if ethHdr != nil { - pkt.LinkHeader = buffer.View(ethHdr) - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: len(ethHdr), + Data: buffer.View(data).ToVectorisedView(), + }) + copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) endpoint.InjectLinkAddr(protocol, remote, pkt) return dataLen, nil } @@ -263,33 +264,22 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { return nil, false } // Ethernet header (TAP only). if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. - if info.Pkt.LinkHeader == nil { - hdr := &header.EthernetFields{ - SrcAddr: info.Route.LocalLinkAddress, - DstAddr: info.Route.RemoteLinkAddress, - Type: info.Proto, - } - if hdr.SrcAddr == "" { - hdr.SrcAddr = d.endpoint.LinkAddress() - } - - eth := make(header.Ethernet, header.EthernetMinimumSize) - eth.Encode(hdr) - vv.AppendView(buffer.View(eth)) - } else { - vv.AppendView(info.Pkt.LinkHeader) + if info.Pkt.LinkHeader().View().IsEmpty() { + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } + vv.AppendView(info.Pkt.LinkHeader().View()) } // Append upper headers. - vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + vv.AppendView(info.Pkt.NetworkHeader().View()) + vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. vv.Append(info.Pkt.Data) @@ -341,18 +331,52 @@ func (d *Device) WriteNotify() { // It is ref-counted as multiple opening files can attach to the same NIC. // The last owner is responsible for deleting the NIC. type tunEndpoint struct { + tunEndpointRefs *channel.Endpoint - refs.AtomicRefCount - stack *stack.Stack nicID tcpip.NICID name string + isTap bool } -// DecRef decrements refcount of e, removes NIC if refcount goes to 0. -func (e *tunEndpoint) DecRef() { - e.DecRefWithDestructor(func() { +// DecRef decrements refcount of e, removing NIC if it reaches 0. +func (e *tunEndpoint) DecRef(ctx context.Context) { + e.tunEndpointRefs.DecRef(func() { e.stack.RemoveNIC(e.nicID) }) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0956d2c65..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,6 +26,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 949b3f2b2..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 63bf40562..94827fc56 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -81,29 +86,39 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +135,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 6a4839fb8..59710352b 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -9,13 +9,16 @@ go_test( "ip_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", ], diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index eddf7b725..b40dde96b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/stack", ], ) @@ -28,5 +29,6 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7f27a840d..b47a7be51 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -15,20 +15,15 @@ // Package arp implements the ARP network protocol. It is used to resolve // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. -// -// To use it in the networking stack, pass arp.NewProtocol() as one of the -// network protocols when calling stack.New. Then add an "arp" address to every -// NIC on the stack that should respond to ARP requests. That is: -// -// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { -// // handle err -// } package arp import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -40,45 +35,74 @@ const ( ProtocolAddress = tcpip.Address("arp") ) -// endpoint implements stack.NetworkEndpoint. +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - protocol *protocol - nicID tcpip.NICID + stack.AddressableEndpointState + + protocol *protocol + + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache + nud stack.NUDHandler } -// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. -func (e *endpoint) DefaultTTL() uint8 { - return 0 +func (e *endpoint) Enable() *tcpip.Error { + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + e.setEnabled(true) + return nil } -func (e *endpoint) MTU() uint32 { - lmtu := e.linkEP.MTU() - return lmtu - uint32(e.MaxHeaderLength()) +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() } -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 } -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() +// setEnabled sets the enabled status for the endpoint. +func (e *endpoint) setEnabled(v bool) { + if v { + atomic.StoreUint32(&e.enabled, 1) + } else { + atomic.StoreUint32(&e.enabled, 0) + } } -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &stack.NetworkEndpointID{ProtocolAddress} +func (e *endpoint) Disable() { + e.setEnabled(false) } -func (e *endpoint) PrefixLen() int { +// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. +func (e *endpoint) DefaultTTL() uint8 { return 0 } +func (e *endpoint) MTU() uint32 { + lmtu := e.linkEP.MTU() + return lmtu - uint32(e.MaxHeaderLength()) +} + func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported @@ -86,7 +110,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return e.protocol.Number() + return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. @@ -99,7 +123,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.ARP(pkt.NetworkHeader) + if !e.isEnabled() { + return + } + + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -107,25 +135,58 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { - return // we have no useful answer, ignore the request + + if e.nud == nil { + if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + } else { + if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + + remoteAddr := tcpip.Address(h.ProtocolAddressSender()) + remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - packet := header.ARP(hdr.Prepend(header.ARPSize)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) - fallthrough // also fill the cache from requests + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + + if e.nud == nil { + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + return + } + + // The solicited, override, and isRouter flags are not available for ARP; + // they are only available for IPv6 Neighbor Advertisements. + e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + // Solicited and unsolicited (also referred to as gratuitous) ARP Replies + // are handled equivalently to a solicited Neighbor Advertisement. + Solicited: true, + // If a different link address is received than the one cached, the entry + // should always go to Stale. + Override: false, + // ARP does not distinguish between router and non-router hosts. + IsRouter: false, + }) } } @@ -142,16 +203,16 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - if addrWithPrefix.Address != ProtocolAddress { - return nil, tcpip.ErrBadLocalAddress - } - return &endpoint{ +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ protocol: p, - nicID: nicID, - linkEP: sender, + nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, - }, nil + nud: nud, + } + e.AddressableEndpointState.Init(e) + return e } // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. @@ -160,28 +221,31 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) copy(h.HardwareAddressSender(), linkEP.LinkAddress()) copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { - return broadcastMAC, true + return header.EthernetBroadcastAddress, true } if header.IsV4MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv4Address(addr), true @@ -190,12 +254,12 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } // SetOption implements stack.NetworkProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.NetworkProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -207,18 +271,14 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return 0, false, false - } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(header.ARPSize) - return 0, false, true + return 0, false, parse.ARP(pkt) } -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - // NewProtocol returns an ARP network protocol. -func NewProtocol() stack.NetworkProtocol { +// +// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" +// address must be added to every NIC that should respond to ARP requests. See +// ProtocolAddress for more details. +func NewProtocol(*stack.Stack) stack.NetworkProtocol { return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 66e67429c..626af975a 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -16,10 +16,12 @@ package arp_test import ( "context" + "fmt" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -32,54 +34,192 @@ import ( ) const ( + nicID = 1 + + stackAddr = tcpip.Address("\x0a\x00\x00\x01") stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + + remoteAddr = tcpip.Address("\x0a\x00\x00\x02") + remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + + unknownAddr = tcpip.Address("\x0a\x00\x00\x03") + + defaultChannelSize = 1 + defaultMTU = 65536 + + // eventChanSize defines the size of event channels used by the neighbor + // cache's event dispatcher. The size chosen here needs to be sufficient to + // queue all the events received during tests before consumption. + // If eventChanSize is too small, the tests may deadlock. + eventChanSize = 32 +) + +type eventType uint8 + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved ) +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + addr tcpip.Address + linkAddr tcpip.LinkAddress + state stack.NeighborState +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) +} + +// arpDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type arpDispatcher struct { + // C is where events are queued + C chan eventInfo +} + +var _ stack.NUDDispatcher = (*arpDispatcher)(nil) + +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + addr: addr, + linkAddr: linkAddr, + state: state, + } + d.C <- e +} + +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + addr: addr, + linkAddr: linkAddr, + state: state, + } + d.C <- e +} + +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + addr: addr, + linkAddr: linkAddr, + state: state, + } + d.C <- e +} + +func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { + select { + case got := <-d.C: + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + return fmt.Errorf("got invalid event (-got +want):\n%s", diff) + } + case <-ctx.Done(): + return fmt.Errorf("%s for %s", ctx.Err(), want) + } + return nil +} + +func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.waitForEvent(ctx, want) +} + +func (d *arpDispatcher) nextEvent() (eventInfo, bool) { + select { + case event := <-d.C: + return event, true + default: + return eventInfo{}, false + } +} + type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack + s *stack.Stack + linkEP *channel.Endpoint + nudDisp *arpDispatcher } -func newTestContext(t *testing.T) *testContext { +func newTestContext(t *testing.T, useNeighborCache bool) *testContext { + c := stack.DefaultNUDConfigurations() + // Transition from Reachable to Stale almost immediately to test if receiving + // probes refreshes positive reachability. + c.BaseReachableTime = time.Microsecond + + d := arpDispatcher{ + // Create an event channel large enough so the neighbor cache doesn't block + // while dispatching events. Blocking could interfere with the timing of + // NUD transitions. + C: make(chan eventInfo, eventChanSize), + } + s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + NUDConfigs: c, + NUDDisp: &d, + UseNeighborCache: useNeighborCache, }) - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNIC(nicID, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress for ipv4 failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if !useNeighborCache { + // The remote address needs to be assigned to the NIC so we can receive and + // verify outgoing ARP packets. The neighbor cache isn't concerned with + // this; the tests that use linkAddrCache expect the ARP responses to be + // received by the same NIC. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { t.Fatalf("AddAddress for arp failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, - NIC: 1, + NIC: nicID, }}) return &testContext{ - t: t, - s: s, - linkEP: ep, + s: s, + linkEP: ep, + nudDisp: &d, } } @@ -88,7 +228,7 @@ func (c *testContext) cleanup() { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := newTestContext(t, false /* useNeighborCache */) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -103,21 +243,21 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) } - for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { + for i, address := range []tcpip.Address{stackAddr, remoteAddr} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pi.Pkt.Header.View()) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) @@ -134,7 +274,7 @@ func TestDirectRequest(t *testing.T) { }) } - inject(stackAddrBad) + inject(unknownAddr) // Sleep tests are gross, but this will only potentially flake // if there's a bug. If there is no bug this will reliably // succeed. @@ -144,3 +284,182 @@ func TestDirectRequest(t *testing.T) { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } } + +func TestDirectRequestWithNeighborCache(t *testing.T) { + c := newTestContext(t, true /* useNeighborCache */) + defer c.cleanup() + + tests := []struct { + name string + senderAddr tcpip.Address + senderLinkAddr tcpip.LinkAddress + targetAddr tcpip.Address + isValid bool + }{ + { + name: "Loopback", + senderAddr: stackAddr, + senderLinkAddr: stackLinkAddr, + targetAddr: stackAddr, + isValid: true, + }, + { + name: "Remote", + senderAddr: remoteAddr, + senderLinkAddr: remoteLinkAddr, + targetAddr: stackAddr, + isValid: true, + }, + { + name: "RemoteInvalidTarget", + senderAddr: remoteAddr, + senderLinkAddr: remoteLinkAddr, + targetAddr: unknownAddr, + isValid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Inject an incoming ARP request. + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), test.senderLinkAddr) + copy(h.ProtocolAddressSender(), test.senderAddr) + copy(h.ProtocolAddressTarget(), test.targetAddr) + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + Data: v.ToVectorisedView(), + }) + + if !test.isValid { + // No packets should be sent after receiving an invalid ARP request. + // There is no need to perform a blocking read here, since packets are + // sent in the same function that handles ARP requests. + if pkt, ok := c.linkEP.Read(); ok { + t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) + } + return + } + + // Verify an ARP response was sent. + pi, ok := c.linkEP.Read() + if !ok { + t.Fatal("expected ARP response to be sent, got none") + } + + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) + } + rep := header.ARP(pi.Pkt.NetworkHeader().View()) + if !rep.IsValid() { + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { + t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { + t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { + t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want) + } + + // Verify the sender was saved in the neighbor cache. + wantEvent := eventInfo{ + eventType: entryAdded, + nicID: nicID, + addr: test.senderAddr, + linkAddr: tcpip.LinkAddress(test.senderLinkAddr), + state: stack.Stale, + } + if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { + t.Fatal(err) + } + + neighbors, err := c.s.Neighbors(nicID) + if err != nil { + t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff) + } + t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr) + } + neighborByAddr[n.Addr] = n + } + + neigh, ok := neighborByAddr[test.senderAddr] + if !ok { + t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr) + } + if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { + t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) + } + if got, want := neigh.LocalAddr, stackAddr; got != want { + t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) + } + if got, want := neigh.State, stack.Stale; got != want { + t.Errorf("got neighbor State = %s, want = %s", got, want) + } + + // No more events should be dispatched + for { + event, ok := c.nudDisp.nextEvent() + if !ok { + break + } + t.Errorf("unexpected %s", event) + } + }) + } +} + +func TestLinkAddressRequest(t *testing.T) { + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: remoteLinkAddr, + expectLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: header.EthernetBroadcastAddress, + }, + } + + for _, test := range tests { + p := arp.NewProtocol(nil) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index d1c728ccf..e247f06a4 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -41,5 +41,8 @@ go_test( "reassembler_test.go", ], library = ":fragmentation", - deps = ["//pkg/tcpip/buffer"], + deps = [ + "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", + ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 2982450f8..e1909fab0 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,28 +17,58 @@ package fragmentation import ( + "errors" "fmt" "log" "time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. -const DefaultReassembleTimeout = 30 * time.Second +const ( + // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. + DefaultReassembleTimeout = 30 * time.Second -// HighFragThreshold is the threshold at which we start trimming old -// fragmented packets. Linux uses a default value of 4 MB. See -// net.ipv4.ipfrag_high_thresh for more information. -const HighFragThreshold = 4 << 20 // 4MB + // HighFragThreshold is the threshold at which we start trimming old + // fragmented packets. Linux uses a default value of 4 MB. See + // net.ipv4.ipfrag_high_thresh for more information. + HighFragThreshold = 4 << 20 // 4MB -// LowFragThreshold is the threshold we reach to when we start dropping -// older fragmented packets. It's important that we keep enough room for newer -// packets to be re-assembled. Hence, this needs to be lower than -// HighFragThreshold enough. Linux uses a default value of 3 MB. See -// net.ipv4.ipfrag_low_thresh for more information. -const LowFragThreshold = 3 << 20 // 3MB + // LowFragThreshold is the threshold we reach to when we start dropping + // older fragmented packets. It's important that we keep enough room for newer + // packets to be re-assembled. Hence, this needs to be lower than + // HighFragThreshold enough. Linux uses a default value of 3 MB. See + // net.ipv4.ipfrag_low_thresh for more information. + LowFragThreshold = 3 << 20 // 3MB + + // minBlockSize is the minimum block size for fragments. + minBlockSize = 1 +) + +var ( + // ErrInvalidArgs indicates to the caller that that an invalid argument was + // provided. + ErrInvalidArgs = errors.New("invalid args") +) + +// FragmentID is the identifier for a fragment. +type FragmentID struct { + // Source is the source address of the fragment. + Source tcpip.Address + + // Destination is the destination address of the fragment. + Destination tcpip.Address + + // ID is the identification value of the fragment. + // + // This is a uint32 because IPv6 uses a 32-bit identification value. + ID uint32 + + // The protocol for the packet. + Protocol uint8 +} // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. @@ -46,14 +76,19 @@ type Fragmentation struct { mu sync.Mutex highLimit int lowLimit int - reassemblers map[uint32]*reassembler + reassemblers map[FragmentID]*reassembler rList reassemblerList size int timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job } // NewFragmentation creates a new Fragmentation. // +// blockSize specifies the fragment block size, in bytes. +// // highMemoryLimit specifies the limit on the memory consumed // by the fragments stored by Fragmentation (overhead of internal data-structures // is not accounted). Fragments are dropped when the limit is reached. @@ -64,7 +99,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -73,39 +108,81 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t lowMemoryLimit = 0 } - return &Fragmentation{ - reassemblers: make(map[uint32]*reassembler), + if blockSize < minBlockSize { + blockSize = minBlockSize + } + + f := &Fragmentation{ + reassemblers: make(map[FragmentID]*reassembler), highLimit: highMemoryLimit, lowLimit: lowMemoryLimit, timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, } + f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) + + return f } // Process processes an incoming fragment belonging to an ID and returns a -// complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { +// complete packet and its protocol number when all the packets belonging to +// that ID have been received. +// +// [first, last] is the range of the fragment bytes. +// +// first must be a multiple of the block size f is configured with. The size +// of the fragment data must be a multiple of the block size, unless there are +// no fragments following this fragment (more set to false). +// +// proto is the protocol number marked in the fragment being processed. It has +// to be given here outside of the FragmentID struct because IPv6 should not use +// the protocol to identify a fragment. +func (f *Fragmentation) Process( + id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) ( + buffer.VectorisedView, uint8, bool, error) { + if first > last { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + } + + if first%f.blockSize != 0 { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + } + + fragmentSize := last - first + 1 + if more && fragmentSize%f.blockSize != 0 { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + } + + if l := vv.Size(); l < int(fragmentSize) { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + } + vv.CapLength(int(fragmentSize)) + f.mu.Lock() r, ok := f.reassemblers[id] - if ok && r.tooOld(f.timeout) { - // This is very likely to be an id-collision or someone performing a slow-rate attack. - f.release(r) - ok = false - } if !ok { - r = newReassembler(id) + r = newReassembler(id, f.clock) f.reassemblers[id] = r + wasEmpty := f.rList.Empty() f.rList.PushFront(r) + if wasEmpty { + // If we have just pushed a first reassembler into an empty list, we + // should kickstart the release job. The release job will keep + // rescheduling itself until the list becomes empty. + f.releaseReassemblersLocked() + } } f.mu.Unlock() - res, done, consumed, err := r.process(first, last, more, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() f.release(r) f.mu.Unlock() - return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed @@ -124,7 +201,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } } f.mu.Unlock() - return res, done, nil + return res, firstFragmentProto, done, nil } func (f *Fragmentation) release(r *reassembler) { @@ -142,3 +219,27 @@ func (f *Fragmentation) release(r *reassembler) { f.size = 0 } } + +// releaseReassemblersLocked releases already-expired reassemblers, then +// schedules the job to call back itself for the remaining reassemblers if +// any. This function must be called with f.mu locked. +func (f *Fragmentation) releaseReassemblersLocked() { + now := f.clock.NowMonotonic() + for { + // The reassembler at the end of the list is the oldest. + r := f.rList.Back() + if r == nil { + // The list is empty. + break + } + elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + if f.timeout > elapsed { + // If the oldest reassembler has not expired, schedule the release + // job so that this function is called back when it has expired. + f.releaseJob.Schedule(f.timeout - elapsed) + break + } + // If the oldest reassembler has already expired, release it. + f.release(r) + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 72c0f53be..189b223c5 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -15,11 +15,13 @@ package fragmentation import ( + "errors" "reflect" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) // vv is a helper to build VectorisedView from different strings. @@ -33,16 +35,18 @@ func vv(size int, pieces ...string) buffer.VectorisedView { } type processInput struct { - id uint32 + id FragmentID first uint16 last uint16 more bool + proto uint8 vv buffer.VectorisedView } type processOutput struct { - vv buffer.VectorisedView - done bool + vv buffer.VectorisedView + proto uint8 + done bool } var processTestCases = []struct { @@ -53,8 +57,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -62,12 +66,23 @@ var processTestCases = []struct { }, }, { + comment: "Next Header protocol mismatch", + in: []processInput{ + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), proto: 6, done: true}, + }, + }, + { comment: "Two IDs", in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -81,19 +96,27 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{}) + firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) if err != nil { - t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { - t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", + in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { + if firstFragmentProto != proto { + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", + in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) + } if _, ok := f.reassemblers[in.id]; ok { t.Errorf("Process(%d) did not remove buffer from reassemblers", i) } @@ -109,53 +132,154 @@ func TestFragmentationProcess(t *testing.T) { } func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(0, 0, 0, true, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) - if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + const ( + reassemblyTimeout = time.Millisecond + protocol = 0xff + ) + + type fragment struct { + first uint16 + last uint16 + more bool + data string + } + + type event struct { + // name is a nickname of this event. + name string + + // clockAdvance is a duration to advance the clock. The clock advances + // before a fragment specified in the fragment field is processed. + clockAdvance time.Duration + + // fragment is a fragment to process. This can be nil if there is no + // fragment to process. + fragment *fragment + + // expectDone is true if the fragmentation instance should report the + // reassembly is done after the fragment is processd. + expectDone bool + + // sizeAfterEvent is the expected size of the fragmentation instance after + // the event. + sizeAfterEvent int } - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") + + half1 := &fragment{first: 0, last: 0, more: true, data: "0"} + half2 := &fragment{first: 1, last: 1, more: false, data: "1"} + + tests := []struct { + name string + events []event + }{ + { + name: "half1 and half2 are reassembled successfully", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2", + fragment: half2, + expectDone: true, + sizeAfterEvent: 0, + }, + }, + }, + { + name: "half1 timeout, half2 timeout", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half1 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + { + name: "half2", + fragment: half2, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half2 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + for _, event := range test.events { + clock.Advance(event.clockAdvance) + if frag := event.fragment; frag != nil { + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + if err != nil { + t.Fatalf("%s: f.Process failed: %s", event.name, err) + } + if done != event.expectDone { + t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) + } + } + if got, want := f.size, event.sizeAfterEvent; got != want { + t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + } + } + }) } } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) // Send first fragment with id = 1. - f.Process(1, 0, 0, true, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) // Send first fragment with id = 2. - f.Process(2, 0, 0, true, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(3, 0, 0, true, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) - if _, ok := f.reassemblers[0]; ok { + if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") } - if _, ok := f.reassemblers[1]; ok { + if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { t.Errorf("Memory limits are not respected: id=1 has not been evicted.") } - if _, ok := f.reassemblers[3]; !ok { + if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") } } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Send the same packet again. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) got := f.size want := 1 @@ -163,3 +287,97 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } + +func TestErrors(t *testing.T) { + tests := []struct { + name string + blockSize uint16 + first uint16 + last uint16 + more bool + data string + err error + }{ + { + name: "exact block size without more", + blockSize: 2, + first: 2, + last: 3, + more: false, + data: "01", + }, + { + name: "exact block size with more", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "01", + }, + { + name: "exact block size with more and extra data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "012", + }, + { + name: "exact block size with more and too little data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size with more", + blockSize: 2, + first: 2, + last: 2, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size without more", + blockSize: 2, + first: 2, + last: 2, + more: false, + data: "0", + }, + { + name: "first not a multiple of block size", + blockSize: 2, + first: 3, + last: 4, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + { + name: "first more than last", + blockSize: 2, + first: 4, + last: 3, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{}) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) + if !errors.Is(err, test.err) { + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) + } + if done { + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 0a83d81f2..9bb051a30 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -32,23 +32,23 @@ type hole struct { type reassembler struct { reassemblerEntry - id uint32 + id FragmentID size int + proto uint8 mu sync.Mutex holes []hole deleted int heap fragHeap done bool - creationTime time.Time + creationTime int64 } -func newReassembler(id uint32) *reassembler { +func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, holes: make([]hole, 0, 16), - deleted: 0, heap: make(fragHeap, 0, 8), - creationTime: time.Now(), + creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +86,18 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed, nil + return buffer.VectorisedView{}, 0, false, consumed, nil + } + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded + // here (instead of just the protocol) because most IP options should be + // derived from the first fragment. + if first == 0 { + r.proto = proto } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,17 +107,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed, nil + return buffer.VectorisedView{}, 0, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) + return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) } - return res, true, consumed, nil -} - -func (r *reassembler) tooOld(timeout time.Duration) bool { - return time.Now().Sub(r.creationTime) > timeout + return res, r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index 7eee0710d..a0a04a027 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -18,6 +18,8 @@ import ( "math" "reflect" "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) type updateHolesInput struct { @@ -94,7 +96,7 @@ var holesTestCases = []struct { func TestUpdateHoles(t *testing.T) { for _, c := range holesTestCases { - r := newReassembler(0) + r := newReassembler(FragmentID{}, &faketime.NullClock{}) for _, i := range c.in { r.updateHoles(i.first, i.last, i.more) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7c8fb3e0a..6861cfdaf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -17,32 +17,44 @@ package ip_test import ( "testing" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( - localIpv4Addr = "\x0a\x00\x00\x01" - localIpv4PrefixLen = 24 - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - localIpv6PrefixLen = 120 - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = "\x0a\x00\x00\x01" + remoteIPv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + nicID = 1 ) +var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv4Addr, + PrefixLen: 24, +} + +var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv6Addr, + PrefixLen: 120, +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. The latter is used to @@ -96,9 +108,10 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ + return stack.TransportPacketHandled } // DeliverTransportControlPacket is called by network endpoints after parsing @@ -156,13 +169,13 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(pkt.Header.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(pkt.Header.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() @@ -172,60 +185,345 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStack() *stack.Stack { - return stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, +func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) + e := channel.New(0, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) + } + + v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) + } + + return s, e +} + +func buildDummyStack(t *testing.T) *stack.Stack { + t.Helper() + + s, _ := buildDummyStackWithLinkEndpoint(t) + return s +} + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + tester testObject + + mu struct { + sync.RWMutex + disabled bool + } +} + +func (*testInterface) ID() tcpip.NICID { + return nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (t *testInterface) Enabled() bool { + t.mu.RLock() + defer t.mu.RUnlock() + return !t.mu.disabled +} + +func (t *testInterface) setEnabled(v bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.mu.disabled = !v +} + +func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { + return &t.tester +} + +func TestSourceAddressValidation(t *testing.T) { + rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + srcAddress tcpip.Address + rxICMP func(*channel.Endpoint, tcpip.Address) + valid bool + }{ + { + name: "IPv4 valid", + srcAddress: "\x01\x02\x03\x04", + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 valid", + srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 multicast", + srcAddress: "\xe0\x00\x00\x01", + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv6 multicast", + srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + rxICMP: rxIPv6ICMP, + valid: false, + }, + { + name: "IPv4 broadcast", + srcAddress: header.IPv4Broadcast, + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv4 subnet broadcast", + srcAddress: func() tcpip.Address { + subnet := localIPv4AddrWithPrefix.Subnet() + return subnet.Broadcast() + }(), + rxICMP: rxIPv4ICMP, + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t) + test.rxICMP(e, test.srcAddress) + + var wantValid uint64 + if test.valid { + wantValid = 1 + } + + if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { + t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { + t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) + } + }) + } +} + +func TestEnableWhenNICDisabled(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protocolFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protocolFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var nic testInterface + nic.setEnabled(false) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, + }) + p := s.NetworkProtocolInstance(test.protoNum) + + // We pass nil for all parameters except the NetworkInterface and Stack + // since Enable only depends on these. + ep := p.NewEndpoint(&nic, nil, nil, nil) + + // The endpoint should initially be disabled, regardless the NIC's enabled + // status. + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Attempting to enable the endpoint while the NIC is disabled should + // fail. + nic.setEnabled(false) + if err := ep.Enable(); err != tcpip.ErrNotPermitted { + t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + } + // ep should consider the NIC's enabled status when determining its own + // enabled status so we "enable" the NIC to read just the endpoint's + // enabled status. + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Enabling the interface after the NIC has been enabled should succeed. + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + if !ep.Enabled() { + t.Fatal("got ep.Enabled() = false, want = true") + } + + // ep should consider the NIC's enabled status when determining its own + // enabled status. + nic.setEnabled(false) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Disabling the endpoint when the NIC is enabled should make the endpoint + // disabled. + nic.setEnabled(true) + ep.Disable() + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + }) + } } func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, } + ep := proto.NewEndpoint(&nic, nil, nil, nil) + defer ep.Close() // Allocate and initialize the payload view. payload := buffer.NewView(100) @@ -233,16 +531,19 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIPv4Addr + nic.tester.dstAddr = remoteIPv4Addr + nic.tester.contents = payload - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -250,20 +551,25 @@ func TestIPv4Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } totalLen := header.IPv4MinimumSize + 30 @@ -274,8 +580,8 @@ func TestIPv4Receive(t *testing.T) { TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. @@ -284,20 +590,24 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = view[header.IPv4MinimumSize:totalLen] - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -307,7 +617,7 @@ func TestIPv4ReceiveControl(t *testing.T) { name string expectedCount int fragmentOffset uint16 - code uint8 + code header.ICMPv4Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -321,20 +631,26 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") + r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") if err != nil { t.Fatal(err) } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) @@ -346,7 +662,7 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: uint8(header.ICMPv4ProtocolNumber), SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, + DstAddr: localIPv4Addr, }) // Create the ICMP header. @@ -364,8 +680,8 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, + SrcAddr: localIPv4Addr, + DstAddr: remoteIPv4Addr, }) // Make payload be non-zero. @@ -375,27 +691,35 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } } func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } totalLen := header.IPv4MinimumSize + 24 @@ -409,8 +733,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { Protocol: 10, FragmentOffset: 0, Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -425,8 +749,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -434,39 +758,54 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } // Send first segment. - pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag1.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.tester.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) } // Send second segment. - pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag2.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } func TestIPv6Send(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } // Allocate and initialize the payload view. @@ -475,16 +814,19 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIPv6Addr + nic.tester.dstAddr = remoteIPv6Addr + nic.tester.contents = payload - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -492,20 +834,24 @@ func TestIPv6Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } totalLen := header.IPv6MinimumSize + 30 @@ -515,8 +861,8 @@ func TestIPv6Receive(t *testing.T) { PayloadLength: uint16(totalLen - header.IPv6MinimumSize), NextHeader: 10, HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -525,21 +871,25 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv6Addr + nic.tester.dstAddr = localIPv6Addr + nic.tester.contents = view[header.IPv6MinimumSize:totalLen] - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -553,7 +903,7 @@ func TestIPv6ReceiveControl(t *testing.T) { expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type - code uint8 + code header.ICMPv6Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -570,7 +920,7 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } r, err := buildIPv6Route( - localIpv6Addr, + localIPv6Addr, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", ) if err != nil { @@ -578,15 +928,20 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, } - + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize @@ -600,7 +955,7 @@ func TestIPv6ReceiveControl(t *testing.T) { NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, SrcAddr: outerSrcAddr, - DstAddr: localIpv6Addr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -616,8 +971,8 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: 100, NextHeader: 10, HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, }) // Build the fragmentation header if needed. @@ -639,19 +994,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv6Addr + nic.tester.dstAddr = localIPv6Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } @@ -663,11 +1018,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // becomes Data. func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { v := view[:len(view)-trunc] - if len(v) < netHdrLen { - return &stack.PacketBuffer{Data: v.ToVectorisedView()} - } - return &stack.PacketBuffer{ - NetworkHeader: v[:netHdrLen], - Data: v[netHdrLen:].ToVectorisedView(), - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 78420d6e6..ee2c23e91 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,9 +10,11 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", @@ -26,14 +28,17 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1b67aa066..eab9a530c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,9 @@ package ipv4 import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -37,8 +40,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // false. // // Drop packet if it doesn't have the basic IPv4 header or if the - // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -53,7 +57,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { @@ -75,45 +79,92 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Echo.Increment() // Only send a reply if the checksum is valid. - wantChecksum := h.Checksum() - // Reset the checksum field to 0 to can calculate the proper - // checksum. We'll have to reset this before we hand the packet - // off. + headerChecksum := h.Checksum() h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - if gotChecksum != wantChecksum { - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) + calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(headerChecksum) + if calculatedChecksum != headerChecksum { + // It's possible that a raw socket still expects to receive this. e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ - Data: pkt.Data.Clone(nil), - NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + // DeliverTransportPacket will take ownership of pkt so don't use it beyond + // this point. Make a deep copy of the data before pkt gets sent as we will + // be modifying fields. + // + // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no + // waiting endpoints. Consider moving responsibility for doing the copy to + // DeliverTransportPacket so that is is only done when needed. + replyData := pkt.Data.ToOwnedView() + replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) + + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + + remoteLinkAddr := r.RemoteLinkAddress + + // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP + // source address MUST be one of its own IP addresses (but not a broadcast + // or multicast address). + localAddr := r.LocalAddress + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // Use the remote link address from the incoming packet. + r.ResolveWith(remoteLinkAddr) + + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the + // header information, we may have to change this code to handle the + // ICMP header no longer being in the data buffer. + + // Because IP and ICMP are so closely intertwined, we need to handcraft our + // IP header to be able to follow RFC 792. The wording on page 13 is as + // follows: + // IP Fields: + // Addresses + // The address of the source in an echo message will be the + // destination of the echo reply message. To form an echo reply + // message, the source and destination addresses are simply reversed, + // the type code changed to 0, and the checksum recomputed. + // + // This was interpreted by early implementors to mean that all options must + // be copied from the echo request IP header to the echo reply IP header + // and this behaviour is still relied upon by some applications. + // + // Create a copy of the IP header we received, options and all, and change + // The fields we need to alter. + // + // We need to produce the entire packet in the data segment in order to + // use WriteHeaderIncludedPacket(). + replyIPHdr.SetSourceAddress(r.LocalAddress) + replyIPHdr.SetDestinationAddress(r.RemoteAddress) + replyIPHdr.SetTTL(r.DefaultTTL()) + + replyICMPHdr := header.ICMPv4(replyData) + replyICMPHdr.SetType(header.ICMPv4EchoReply) + replyICMPHdr.SetChecksum(0) + replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0)) + + replyVV := buffer.View(replyIPHdr).ToVectorisedView() + replyVV.AppendView(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: replyVV, }) + replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - vv := pkt.Data.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + // The checksum will be calculated so we don't need to do it here. sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), - TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: vv, - TransportHeader: buffer.View(pkt), - }); err != nil { + if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return } @@ -129,6 +180,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) @@ -165,3 +219,177 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Invalid.Increment() } } + +// ======= ICMP Error packet generation ========= + +// icmpReason is a marker interface for IPv4 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// icmpReasonProtoUnreachable is an error where the transport protocol is +// not supported. +type icmpReasonProtoUnreachable struct{} + +func (*icmpReasonProtoUnreachable) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv4 and sends it back to the remote device that sent +// the problematic packet. It incorporates as much of that packet as +// possible as well as any error metadata as is available. returnError +// expects pkt to hold a valid IPv4 packet as per the wire format. +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + sent := r.Stats().ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + // We check we are responding only when we are allowed to. + // See RFC 1812 section 4.3.2.7 (shown below). + // + // ========= + // 4.3.2.7 When Not to Send ICMP Errors + // + // An ICMP error message MUST NOT be sent as the result of receiving: + // + // o An ICMP error message, or + // + // o A packet which fails the IP header validation tests described in + // Section [5.2.2] (except where that section specifically permits + // the sending of an ICMP error message), or + // + // o A packet destined to an IP broadcast or IP multicast address, or + // + // o A packet sent as a Link Layer broadcast or multicast, or + // + // o Any fragment of a datagram other then the first fragment (i.e., a + // packet for which the fragment offset in the IP header is nonzero). + // + // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in + // response to a non-initial fragment, but it currently can not happen. + + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + return nil + } + + networkHeader := pkt.NetworkHeader().View() + transportHeader := pkt.TransportHeader().View() + + // Don't respond to icmp error packets. + if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + // TODO(gvisor.dev/issue/3810): + // Unfortunately the current stack pretty much always has ICMPv4 headers + // in the Data section of the packet but there is no guarantee that is the + // case. If this is the case grab the header to make it like all other + // packet types. When this is cleaned up the Consume should be removed. + if transportHeader.IsEmpty() { + var ok bool + transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) + if !ok { + return nil + } + } else if transportHeader.Size() < header.ICMPv4MinimumSize { + return nil + } + // We need to decide to explicitly name the packets we can respond to or + // the ones we can not respond to. The decision is somewhat arbitrary and + // if problems arise this could be reversed. It was judged less of a breach + // of protocol to not respond to unknown non-error packets than to respond + // to unknown error packets so we take the first approach. + switch header.ICMPv4(transportHeader).Type() { + case + header.ICMPv4EchoReply, + header.ICMPv4Echo, + header.ICMPv4Timestamp, + header.ICMPv4TimestampReply, + header.ICMPv4InfoRequest, + header.ICMPv4InfoReply: + default: + // Assume any type we don't know about may be an error type. + return nil + } + } + + // Now work out how much of the triggering packet we should return. + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes. + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 and RFC 792 where it mentioned that at + // least 8 bytes of the payload must be included. Today linux and other + // systems implement the RFC 1812 definition and not the original + // requirement. We treat 8 bytes as the minimum but will try send more. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + + if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + return nil + } + + payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + + // The buffers used by pkt may be used elsewhere in the system. + // For example, an AF_RAW or AF_PACKET socket may use what the transport + // protocol considers an unreachable destination. Thus we deep copy pkt to + // prevent multiple ownership and SR errors. The new copy is a vectorized + // view with the entire incoming IP packet reassembled and truncated as + // required. This is now the payload of the new ICMP packet and no longer + // considered a packet in its own right. + newHeader := append(buffer.View(nil), networkHeader...) + newHeader = append(newHeader, transportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) + payload.CapLength(payloadLen) + + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + switch reason.(type) { + case *icmpReasonPortUnreachable: + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + case *icmpReasonProtoUnreachable: + icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + counter := sent.DstUnreachable + + if err := r.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + icmpPkt, + ); err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7e9f16c90..a2be64fb8 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -12,21 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv4 contains the implementation of the ipv4 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv4.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv4 contains the implementation of the ipv4 network protocol. package ipv4 import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -45,70 +42,139 @@ const ( // buckets is the number of identifier buckets. buckets = 2048 + + // The size of a fragment block, in bytes, as per RFC 791 section 3.1, + // page 14. + fragmentblockSize = 8 ) +var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() + +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - linkEP stack.LinkEndpoint - dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation - protocol *protocol - stack *stack.Stack + nic stack.NetworkInterface + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher + protocol *protocol + + // enabled is set to 1 when the enpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + } } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), - protocol: p, - stack: st, + nic: nic, + linkEP: nic.LinkEndpoint(), + dispatcher: dispatcher, + protocol: p, } + e.mu.addressableEndpointState.Init(e) + return e +} - return e, nil +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Create an endpoint to receive broadcast packets on this interface. + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + return err + } + // We have no need for the address endpoint. + ep.DecRef() + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) + return err } -// DefaultTTL is the default time-to-live value for this endpoint. -func (e *endpoint) DefaultTTL() uint8 { - return e.protocol.DefaultTTL() +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. -func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 } -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) + } + + // The address may have already been removed. + if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) + } } -// ID returns the ipv4 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id +// DefaultTTL is the default time-to-live value for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return e.protocol.DefaultTTL() } -// PrefixLen returns the ipv4 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize + return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // GSOMaxSize returns the maximum GSO packet size. @@ -125,14 +191,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in pkt.Header but does not -// assume that only the IP header is in pkt.Header. It assumes that the input -// packet's stated length matches the length of the header+payload. mtu -// includes the IP header and options. This does not support the DontFragment -// IP flag. +// write. It assumes that the IP header is already present in pkt.NetworkHeader. +// pkt.TransportHeader may be set. mtu includes the IP header and options. This +// does not support the DontFragment IP flag. func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.Header.View()) + ip := header.IPv4(pkt.NetworkHeader().View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -146,91 +210,89 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := pkt.Header.AvailableLength() + + // Keep the length reserved for link-layer, we need to create fragments with + // the same reserved length. + reservedForLink := pkt.AvailableHeaderBytes() + + // Destroy the packet, pull all payloads out for fragmentation. + transHeader, data := pkt.TransportHeader().View(), pkt.Data + + // Where possible, the first fragment that is sent has the same + // number of bytes reserved for header as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + transFitsFirst := len(transHeader) <= innerMTU + for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // pkt.Header.UsedLength() as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) + reserve := reservedForLink + int(ip.HeaderLength()) + if i == 0 && transFitsFirst { + // Reserve for transport header if it's going to be put in the first + // fragment. + reserve += len(transHeader) } + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: reserve, + }) + fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + + // Copy data for the fragment. + avail := innerMTU + + if n := len(transHeader); n > 0 { + if n > avail { + n = avail + } + if i == 0 && transFitsFirst { + copy(fragPkt.TransportHeader().Push(n), transHeader) + } else { + fragPkt.Data.AppendView(transHeader[:n:n]) + } + transHeader = transHeader[n:] + avail -= n + } + + if avail > 0 { + n := data.Size() + if n > avail { + n = avail + } + data.ReadToVV(&fragPkt.Data, n) + avail -= n + } + + copied := uint16(innerMTU - avail) + + // Set lengths in header and calculate checksum. + h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) + copy(h, ip) if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + copied) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := pkt.Data.Clone(nil) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes - // from the header. - if outerMTU >= pkt.Header.UsedLength() { - // This fragment can fit all of pkt.Header and possibly - // some of pkt.Data, too. - newPayload := pkt.Data.Clone(nil) - newPayloadLength := outerMTU - pkt.Header.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of pkt.Header. - startOfHdr := pkt.Header - startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of pkt.Header into the pkt.Data - // that remains to be sent. - restOfHdr := pkt.Header.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(pkt.Data) - pkt.Data = tmp + offset += copied + + // Send out the fragment. + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(n - i)) + return err } + r.Stats().IP.PacketsSent.Increment() } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -242,31 +304,33 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - return ip + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pkt, params) - nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() return nil } - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. - // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than - // only NATted packets, but removing this check short circuits broadcasts - // before they are sent out to other hosts. + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { - netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) @@ -282,10 +346,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.Stats().IP.PacketsSent.Increment() @@ -302,25 +367,28 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pkt, params) pkt = pkt.Next() } - nicName := e.stack.FindNICNameFromID(e.NICID()) + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - // Slow Path as we are dropping some packets in the batch degrade to + // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -328,8 +396,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -340,12 +408,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err } n++ } r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, nil + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -376,13 +448,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. @@ -396,33 +467,47 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } + if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - - ip = ip[:ip.HeaderLength()] - pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) - pkt.Data.TrimFront(int(ip.HeaderLength())) - return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv4(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + if !e.isEnabled() { + return + } + + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + // As per RFC 1122 section 3.2.1.3: + // When a host sends any datagram, the IP source address MUST + // be one of its own IP addresses (but not a broadcast or + // multicast address). + if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() return } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -430,18 +515,37 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < h.FragmentOffset() { + start := h.FragmentOffset() + // Drop the fragment if the size of the reassembled payload would exceed the + // maximum payload size. + // + // Note that this addition doesn't overflow even on 32bit architecture + // because pkt.Data.Size() should not exceed 65535 (the max IP datagram + // size). Otherwise the packet would've been rejected as invalid before + // reaching here. + if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data) + proto := h.Protocol() + pkt.Data, _, ready, err = e.protocol.fragmentation.Process( + // As per RFC 791 section 2.3, the identification value is unique + // for a source-destination pair and protocol. + fragmentation.FragmentID{ + Source: h.SourceAddress(), + Destination: h.DestinationAddress(), + ID: uint32(h.ID()), + Protocol: proto, + }, + start, + start+uint16(pkt.Data.Size())-1, + h.More(), + proto, + pkt.Data, + ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -451,27 +555,166 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } } + + r.Stats().IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - pkt.NetworkHeader.CapLength(int(h.HeaderLength())) + // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport + // headers, the setting of the transport number here should be + // unnecessary and removed. + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt) return } - r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, pkt) + + switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + case stack.TransportPacketHandled: + case stack.TransportPacketDestinationPortUnreachable: + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC: 1122 Section 3.2.2.1 + // A host SHOULD generate Destination Unreachable messages with code: + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported + _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt) + default: + panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) + } } // Close cleans up resources associated with the endpoint. -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + e.disableLocked() + e.mu.addressableEndpointState.Cleanup() +} + +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.RemovePermanentAddress(addr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + + loopback := e.nic.IsLoopback() + addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + // IPv4 has a notion of a subnet broadcast address and considers the + // loopback interface bound to an address's whole subnet (on linux). + return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) + }) + if addressEndpoint != nil { + return addressEndpoint + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) + if err != nil { + // AddAddress only returns an error if the address is already assigned, + // but we just checked above if the address exists so we expect no error. + panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) + } + return addressEndpoint +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV4MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) type protocol struct { - ids []uint32 - hashIV uint32 + stack *stack.Stack // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. defaultTTL uint32 + + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + + ids []uint32 + hashIV uint32 + + fragmentation *fragmentation.Fragmentation } // Number returns the ipv4 protocol number. @@ -496,10 +739,10 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case tcpip.DefaultTTLOption: - p.SetDefaultTTL(uint8(v)) + case *tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(*v)) return nil default: return tcpip.ErrUnknownProtocolOption @@ -507,7 +750,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) @@ -533,33 +776,28 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return 0, false, false } - ipHdr := header.IPv4(hdr) - // If there are options, pull those into hdr as well. - if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(headerLen) - if !ok { - panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) - } - ipHdr = header.IPv4(hdr) - } + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true +} - // If this is a fragment, don't bother parsing the transport header. - parseTransportHeader := true - if ipHdr.More() || ipHdr.FragmentOffset() != 0 { - parseTransportHeader = false - } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) - return ipHdr.TransportProtocol(), parseTransportHeader, true +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + if v { + atomic.StoreUint32(&p.forwarding, 1) + } else { + atomic.StoreUint32(&p.forwarding, 0) + } } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -583,7 +821,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui } // NewProtocol returns an IPv4 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -593,5 +831,11 @@ func NewProtocol() stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} + return &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 11e579c4b..712fbb861 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,17 +17,21 @@ package ipv4_test import ( "bytes" "encoding/hex" - "math/rand" + "math" + "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -35,8 +39,8 @@ import ( func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) const defaultMTU = 65536 @@ -91,25 +95,274 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s +// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and +// checks the response. +func TestIPv4Sanity(t *testing.T) { + const ( + defaultMTU = header.IPv6MinimumMTU + ttl = 255 + nicID = 1 + randomSequence = 123 + randomIdent = 42 + ) + var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + ) + + tests := []struct { + name string + headerLength uint8 // value of 0 means "use correct size" + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + shouldFail bool + expectICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + options []byte + }{ + { + name: "valid", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + }, + // The TTL tests check that we are not rejecting an incoming packet + // with a zero or one TTL, which has been a point of confusion in the + // past as RFC 791 says: "If this field contains the value zero, then the + // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies + // for the case of the destination host, stating as follows. + // + // A host MUST NOT send a datagram with a Time-to-Live (TTL) + // value of zero. + // + // A host MUST NOT discard a datagram just because it was + // received with TTL less than 2. + { + name: "zero TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 0, + shouldFail: false, + }, + { + name: "one TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + shouldFail: false, + }, + { + name: "End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{0, 0, 0, 0}, + }, + { + name: "NOP options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 1, 1}, + }, + { + name: "NOP and End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 0, 0}, + }, + { + name: "bad header length", + headerLength: header.IPv4MinimumSize - 1, + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (0)", + maxTotalLength: 0, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip + icmp - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad protocol", + maxTotalLength: defaultMTU, + transportProtocol: 99, + TTL: ttl, + shouldFail: true, + expectICMP: true, + ICMPType: header.ICMPv4DstUnreachable, + ICMPCode: header.ICMPv4ProtoUnreachable, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + } + + // Default routes for IPv4 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // Round up the header size to the next multiple of 4 as RFC 791, page 11 + // says: "Internet Header Length is the length of the internet header + // in 32 bit words..." and on page 23: "The internet header padding is + // used to ensure that the internet header ends on a 32 bit boundary." + ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1) + + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + + // Specify ident/seq to make sure we get the same in the response. + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + if test.maxTotalLength < totalLen { + totalLen = test.maxTotalLength + } + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHeaderLength), + TotalLength: totalLen, + Protocol: test.transportProtocol, + TTL: test.TTL, + SrcAddr: remoteIPv4Addr, + DstAddr: ipv4Addr.Address, + }) + if n := copy(ip.Options(), test.options); n != len(test.options) { + t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options)) + } + // Override the correct value if the test case specified one. + if test.headerLength != 0 { + ip.SetHeaderLength(test.headerLength) + } + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e.Read() + if !ok { + if test.shouldFail { + if test.expectICMP { + t.Fatal("expected ICMP error response missing") + } + return // Expected silent failure. + } + t.Fatal("expected ICMP echo reply missing") + } + + // Check the route that brought the packet to us. + if reply.Route.LocalAddress != ipv4Addr.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) + } + if reply.Route.RemoteAddress != remoteIPv4Addr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) + } + + // Make sure it's all in one buffer. + vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) + replyIPHeader := header.IPv4(vv.ToView()) + + // At this stage we only know it's an IP header so verify that much. + checker.IPv4(t, replyIPHeader, + checker.SrcAddr(ipv4Addr.Address), + checker.DstAddr(remoteIPv4Addr), + ) + + // All expected responses are ICMP packets. + if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { + t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + } + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) + + // Sanity check the response. + switch replyICMPHeader.Type() { + case header.ICMPv4DstUnreachable: + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + if !test.shouldFail || !test.expectICMP { + t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + return + case header.ICMPv4EchoReply: + checker.IPv4(t, replyIPHeader, + checker.IPv4HeaderLength(ipHeaderLength), + checker.IPv4Options(test.options), + checker.IPFullLength(uint16(requestPkt.Size())), + checker.ICMPv4( + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Seq(randomSequence), + checker.ICMPv4Ident(randomIdent), + checker.ICMPv4Checksum(), + ), + ) + if test.shouldFail { + t.Fatalf("unexpected Echo Reply packet\n") + } + default: + t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + } + }) } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload } // comparePayloads compared the contents of all the packets against the contents @@ -117,9 +370,9 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) + source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) + vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,8 +385,7 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -144,12 +396,22 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if i == 0 { + got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() + // sourcePacketInfo does not have NetworkHeader added, simulate one. + want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() + // Check that it kept the transport header in packet.TransportHeader if + // it fits in the first fragment. + if want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } + if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { + t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { @@ -172,101 +434,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type errorChannel struct { - *channel.Endpoint - Ch chan *stack.PacketBuffer - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan *stack.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { - stack.Route - linkEP *errorChannel -} - -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - return context{ - Route: r, - linkEP: ep, - } -} - func TestFragmentation(t *testing.T) { var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { manyPayloadViewsSizes[i] = 7 } fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + extraHeaderReserveLength int + payloadViewsSizes []int + expectedFrags int }{ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, @@ -281,43 +461,29 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := &stack.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - } - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + source := pkt.Clone() + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) if err != nil { - t.Errorf("err got %v, want %v", err, nil) + t.Errorf("got err = %s, want = nil", err) } - var results []*stack.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } + if got := len(ep.WrittenPackets); got != ft.expectedFrags { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - compareFragments(t, results, source, ft.mtu) + compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } } @@ -328,155 +494,376 @@ func TestFragmentationErrors(t *testing.T) { fragTests := []struct { description string mtu uint32 - hdrLength int + transportHeaderLength int payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error + err *tcpip.Error + allowPackets int + fragmentCount int }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + { + description: "NoFrag", + mtu: 2000, + transportHeaderLength: 0, + payloadViewsSizes: []int{1000}, + err: tcpip.ErrAborted, + allowPackets: 0, + fragmentCount: 1, + }, + { + description: "ErrorOnFirstFrag", + mtu: 500, + transportHeaderLength: 0, + payloadViewsSizes: []int{1000}, + err: tcpip.ErrAborted, + allowPackets: 0, + fragmentCount: 3, + }, + { + description: "ErrorOnSecondFrag", + mtu: 500, + transportHeaderLength: 0, + payloadViewsSizes: []int{1000}, + err: tcpip.ErrAborted, + allowPackets: 1, + fragmentCount: 3, + }, + { + description: "ErrorOnFirstFragMTUSmallerThanHeader", + mtu: 500, + transportHeaderLength: 1000, + payloadViewsSizes: []int{500}, + err: tcpip.ErrAborted, + allowPackets: 0, + fragmentCount: 4, + }, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + r := buildRoute(t, ep) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } + }, pkt) + if err != ft.err { + t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) } }) } } func TestInvalidFragments(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 6 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + autoChecksum bool // if true, the Checksum field will be overwritten. + } + // These packets have both IHL and TotalLength set to 0. - testCases := []struct { + tests := []struct { name string - packets [][]byte + fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 }{ { - "ihl_totallen_zero_valid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - "ihl_totallen_zero_invalid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset non-zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, + FragmentOffset: 59776, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, }, - 1, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, }, { - // Total Length of 37(20 bytes IP header + 17 bytes of - // payload) - // Frag Offset of 0x1ffe = 8190*8 = 65520 - // Leading to the fragment end to be past 65535. - "ihl_totallen_valid_invalid_frag_offset_1", - [][]byte{ - {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, }, - // The following 3 tests were found by running a fuzzer and were - // triggering a panic in the IPv4 reassembler code. { - "ihl_less_than_ipv4_minimum_size_1", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + // Payload 17 octets and Fragment offset 65520 + // Leading to the fragment end to be past 65536. + name: "fragment ends past 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 17, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(17), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - "ihl_less_than_ipv4_minimum_size_2", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + // Payload 16 octets and fragment offset 65520 + // Leading to the fragment end to be exactly 65536. + name: "fragment ends exactly at 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(16), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 0, + wantMalformedFragments: 0, }, { - "ihl_less_than_ipv4_minimum_size_3", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL less than IPv4 minimum size", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 1944, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize - 12, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 2, + wantMalformedFragments: 0, }, { - "fragment_with_short_total_len_extra_payload", - [][]byte{ - {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "fragment with short TotalLength and extra payload", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 28816, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 4, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - "multiple_fragments_with_more_fragments_set_to_false", - [][]byte{ - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + name: "multiple fragments with More Fragments flag set to false", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 128, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - const nicID tcpip.NICID = 42 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, }, }) + e := channel.New(0, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } - var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) - var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) - ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicID, sniffer.New(ep)) + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + len(f.payload) + hdr := buffer.NewPrependable(pktSize) - for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + copy(ip[header.IPv4MinimumSize:], f.payload) + + if f.autoChecksum { + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + } + + vv := hdr.View().ToVectorisedView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) } - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) } }) @@ -486,12 +873,16 @@ func TestInvalidFragments(t *testing.T) { // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { - const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - const nicID = 1 + const ( + nicID = 1 + + addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 + addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 + addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + ) // Build and return a UDP header containing payload. - udpGen := func(payloadLen int, multiplier uint8) buffer.View { + udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { payload := buffer.NewView(payloadLen) for i := 0; i < len(payload); i++ { payload[i] = uint8(i) * multiplier @@ -507,20 +898,32 @@ func TestReceiveFragments(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) sum = header.Checksum(payload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) return hdr.View() } // UDP header plus a payload of 0..256 - ipv4Payload1 := udpGen(256, 1) - udpPayload1 := ipv4Payload1[header.UDPMinimumSize:] + ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) + udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] + ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) + udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] // UDP header plus a payload of 0..256 in increments of 2. - ipv4Payload2 := udpGen(128, 2) - udpPayload2 := ipv4Payload2[header.UDPMinimumSize:] + ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) + udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] + // UDP header plus a payload of 0..256 in increments of 3. + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 791 section 3.1 page 14). + ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) + udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] + // Used to test the max reassembled payload length (65,535 octets). + ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) + udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] type fragmentData struct { + srcAddr tcpip.Address + dstAddr tcpip.Address id uint16 flags uint8 fragmentOffset uint16 @@ -536,22 +939,40 @@ func TestReceiveFragments(t *testing.T) { name: "No fragmentation", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 0, - payload: ipv4Payload1, + payload: ipv4Payload1Addr1ToAddr2, }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "No fragmentation with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2, + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, }, { name: "More fragments without payload", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1, + payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: nil, @@ -560,10 +981,12 @@ func TestReceiveFragments(t *testing.T) { name: "Non-zero fragment offset without payload", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 8, - payload: ipv4Payload1, + payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: nil, @@ -572,34 +995,108 @@ func TestReceiveFragments(t *testing.T) { name: "Two fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments out of order", + fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload3Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2[:63], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 63, + payload: ipv4Payload3Addr1ToAddr2[63:], + }, + }, + expectedPayloads: nil, }, { name: "Second fragment has MoreFlags set", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: nil, @@ -608,16 +1105,20 @@ func TestReceiveFragments(t *testing.T) { name: "Two fragments with different IDs", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 2, flags: 0, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: nil, @@ -626,52 +1127,122 @@ func TestReceiveFragments(t *testing.T) { name: "Two interleaved fragmented packets", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 2, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload2[:64], + payload: ipv4Payload2Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], }, { + srcAddr: addr1, + dstAddr: addr2, id: 2, flags: 0, fragmentOffset: 64, - payload: ipv4Payload2[64:], + payload: ipv4Payload2Addr1ToAddr2[64:], }, }, - expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets from different sources but with same ID", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr3, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr3ToAddr2[:32], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr3, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 32, + payload: ipv4Payload1Addr3ToAddr2[32:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, }, { name: "Fragment without followup", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, }, expectedPayloads: nil, }, + { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Setup a stack and endpoint. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -711,16 +1282,16 @@ func TestReceiveFragments(t *testing.T) { FragmentOffset: frag.fragmentOffset, TTL: 64, Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: addr1, - DstAddr: addr2, + SrcAddr: frag.srcAddr, + DstAddr: frag.dstAddr, }) vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { @@ -743,3 +1314,189 @@ func TestReceiveFragments(t *testing.T) { }) } } + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + // Parameterize the tests to run with both WritePacket and WritePackets. + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + } + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatalf("NewSubnet(_, _) failed: %v", err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) + } + return rt +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 3f71fc520..97adbcbd4 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -5,16 +5,19 @@ package(licenses = ["notice"]) go_library( name = "ipv6", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "ndp.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", - "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", ], ) @@ -35,10 +38,11 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go index d199ded6a..09ba133b1 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go @@ -14,7 +14,7 @@ // Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. -package stack +package ipv6 import "strconv" diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2ff7eedf4..8e9def6b8 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -39,8 +39,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // is truncated, which would cause IsValid to return false. // // Drop packet if it doesn't have the basic IPv6 header or if the - // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -67,7 +68,60 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) +} + +// getLinkAddrOption searches NDP options for a given link address option using +// the provided getAddr function as a filter. Returns the link address if +// found; otherwise, returns the zero link address value. Also returns true if +// the options are valid as per the wire format, false otherwise. +func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) { + var linkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + return "", false + } + if done { + break + } + if addr := getAddr(opt); len(addr) != 0 { + // No RFCs define what to do when an NDP message has multiple Link-Layer + // Address options. Since no interface can have multiple link-layer + // addresses, we consider such messages invalid. + if len(linkAddr) != 0 { + return "", false + } + linkAddr = addr + } + } + return linkAddr, true +} + +// getSourceLinkAddr searches NDP options for the source link address option. +// Returns the link address if found; otherwise, returns the zero link address +// value. Also returns true if the options are valid as per the wire format, +// false otherwise. +func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { + return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress { + if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok { + return src.EthernetAddress() + } + return "" + }) +} + +// getTargetLinkAddr searches NDP options for the target link address option. +// Returns the link address if found; otherwise, returns the zero link address +// value. Also returns true if the options are valid as per the wire format, +// false otherwise. +func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { + return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress { + if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok { + return dst.EthernetAddress() + } + return "" + }) } func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { @@ -83,7 +137,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } h := header.ICMPv6(v) - iph := header.IPv6(pkt.NetworkHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // @@ -128,13 +182,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } @@ -144,22 +200,16 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // NDP messages cannot be fragmented. Also note that in the common case NDP // datagrams are very small and ToView() will not incur allocations. ns := header.NDPNeighborSolicit(payload.ToView()) - it, err := ns.Options().Iter(true) - if err != nil { - // If we have a malformed NDP NS option, drop the packet. + targetAddr := ns.TargetAddress() + + // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast + // address. + if header.IsV6MulticastAddress(targetAddr) { received.Invalid.Increment() return } - targetAddr := ns.TargetAddress() - s := r.Stack() - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now, drop this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is // attempting to perform address resolution on the target. In this case, @@ -172,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // stack know so it can handle such a scenario and do nothing further with // the NS. if r.RemoteAddress == header.IPv6Any { - s.DupTentativeAddrDetected(e.nicID, targetAddr) + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } } // Do not handle neighbor solicitations targeted to an address that is @@ -184,39 +247,22 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // so the packet is processed as defined in RFC 4861, as per RFC 4862 // section 5.4.3. - // Is the NS targetting us? - if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + // Is the NS targeting us? + if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } - // If the NS message contains the Source Link-Layer Address option, update - // the link address cache with the value of the option. - // - // TODO(b/148429853): Properly process the NS message and do Neighbor - // Unreachability Detection. - var sourceLinkAddr tcpip.LinkAddress - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) - } - if done { - break - } - - switch opt := opt.(type) { - case header.NDPSourceLinkLayerAddressOption: - // No RFCs define what to do when an NS message has multiple Source - // Link-Layer Address options. Since no interface can have multiple - // link-layer addresses, we consider such messages invalid. - if len(sourceLinkAddr) != 0 { - received.Invalid.Increment() - return - } + it, err := ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. + received.Invalid.Increment() + return + } - sourceLinkAddr = opt.EthernetAddress() - } + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return } unspecifiedSource := r.RemoteAddress == header.IPv6Any @@ -234,8 +280,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } else if unspecifiedSource { received.Invalid.Increment() return + } else if e.nud != nil { + e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } // ICMPv6 Neighbor Solicit messages are always sent to @@ -274,8 +322,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + }) + packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -291,9 +342,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -301,7 +350,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize { received.Invalid.Increment() return } @@ -311,28 +360,34 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // 5, NDP messages cannot be fragmented. Also note that in the common case // NDP datagrams are very small and ToView() will not incur allocations. na := header.NDPNeighborAdvert(payload.ToView()) - it, err := na.Options().Iter(true) - if err != nil { - // If we have a malformed NDP NA option, drop the packet. - received.Invalid.Increment() - return - } - targetAddr := na.TargetAddress() - stack := r.Stack() - - if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now short-circuit this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - stack.DupTentativeAddrDetected(e.nicID, targetAddr) + // + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } + return + } + + it, err := na.Options().Iter(false /* check */) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.Invalid.Increment() return } @@ -345,58 +400,64 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // TODO(b/143147598): Handle the scenario described above. Also inform the // netstack integration that a duplicate address was detected outside of // DAD. + targetLinkAddr, ok := getTargetLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - // - // TODO(b/148429853): Properly process the NA message and do Neighbor - // Unreachability Detection. - var targetLinkAddr tcpip.LinkAddress - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) - } - if done { - break - } - - switch opt := opt.(type) { - case header.NDPTargetLinkLayerAddressOption: - // No RFCs define what to do when an NA message has multiple Target - // Link-Layer Address options. Since no interface can have multiple - // link-layer addresses, we consider such messages invalid. - if len(targetLinkAddr) != 0 { - received.Invalid.Increment() - return - } - - targetLinkAddr = opt.EthernetAddress() + if len(targetLinkAddr) != 0 { + if e.nud == nil { + e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) + return } - } - if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + Solicited: na.SolicitedFlag(), + Override: na.OverrideFlag(), + IsRouter: na.RouterFlag(), + }) } case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { received.Invalid.Increment() return } - pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + + remoteLinkAddr := r.RemoteLinkAddress + + // As per RFC 4291 section 2.7, multicast addresses must not be used as + // source addresses in IPv6 packets. + localAddr := r.LocalAddress + if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // Use the link address from the source of the original packet. + r.ResolveWith(remoteLinkAddr) + + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: pkt.Data, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -418,27 +479,75 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6RouterSolicit: received.RouterSolicit.Increment() - if !isNDPValid() { + + // + // Validate the RS as per RFC 4861 section 6.1.1. + // + + // Is the NDP payload of sufficient size to hold a Router Solictation? + if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.Invalid.Increment() return } - case header.ICMPv6RouterAdvert: - received.RouterAdvert.Increment() + stack := r.Stack() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + // Is the networking stack operating as a router? + if !stack.Forwarding(ProtocolNumber) { + // ... No, silently drop the packet. + received.RouterOnlyPacketsDroppedByHost.Increment() + return + } + + // Note that in the common case NDP datagrams are very small and ToView() + // will not incur allocations. + rs := header.NDPRouterSolicit(payload.ToView()) + it, err := rs.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. received.Invalid.Increment() return } - routerAddr := iph.SourceAddress() + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } + + // If the RS message has the source link layer option, update the link + // address cache with the link address for the source of the message. + if len(sourceLinkAddr) != 0 { + // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, it SHOULD be included on link layers that have addresses. + if r.RemoteAddress == header.IPv6Any { + received.Invalid.Increment() + return + } + + if e.nud != nil { + // A RS with a specified source IP address modifies the NUD state + // machine in the same way a reachability probe would. + e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + } + } + + case header.ICMPv6RouterAdvert: + received.RouterAdvert.Increment() // // Validate the RA as per RFC 4861 section 6.1.2. // + // Is the NDP payload of sufficient size to hold a Router Advertisement? + if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + received.Invalid.Increment() + return + } + + routerAddr := iph.SourceAddress() + // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { // ...No, silently drop the packet. @@ -446,16 +555,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. + // Note that in the common case NDP datagrams are very small and ToView() + // will not incur allocations. ra := header.NDPRouterAdvert(payload.ToView()) - opts := ra.Options() + it, err := ra.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. + received.Invalid.Increment() + return + } - // Are options valid as per the wire format? - if _, err := opts.Iter(true); err != nil { - // ...No, silently drop the packet. + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { received.Invalid.Increment() return } @@ -465,12 +576,33 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // as RFC 4861 section 6.1.2 is concerned. // - // Tell the NIC to handle the RA. - stack := r.Stack() - rxNICID := r.NICID() - stack.HandleNDPRA(rxNICID, routerAddr, ra) + // If the RA has the source link layer option, update the link address + // cache with the link address for the advertised router. + if len(sourceLinkAddr) != 0 && e.nud != nil { + e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + } + + e.mu.Lock() + e.mu.ndp.handleRA(routerAddr, ra) + e.mu.Unlock() case header.ICMPv6RedirectMsg: + // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating + // this redirect message, as per RFC 4871 section 7.3.3: + // + // "A Neighbor Cache entry enters the STALE state when created as a + // result of receiving packets other than solicited Neighbor + // Advertisements (i.e., Router Solicitations, Router Advertisements, + // Redirects, and Neighbor Solicitations). These packets contain the + // link-layer address of either the sender or, in the case of Redirect, + // the redirection target. However, receipt of these link-layer + // addresses does not confirm reachability of the forward-direction path + // to that node. Placing a newly created Neighbor Cache entry for which + // the link-layer address is known in the STALE state provides assurance + // that path failures are detected quickly. In addition, should a cached + // link-layer address be modified due to receiving one of the above + // messages, the state SHOULD also be set to STALE to provide prompt + // verification that the path to the new link-layer address is working." received.RedirectMsg.Increment() if !isNDPValid() { received.Invalid.Increment() @@ -494,8 +626,6 @@ const ( icmpV6LengthOffset = 25 ) -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. @@ -504,7 +634,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) // TODO(b/148672031): Use stack.FindRoute instead of manually creating the @@ -513,19 +643,26 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + }) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + icmpHdr.SetType(header.ICMPv6NeighborSolicit) + copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) + icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr + icmpHdr[icmpV6LengthOffset] = 1 + copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -535,9 +672,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -547,3 +682,159 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return tcpip.LinkAddress([]byte(nil)), false } + +// ======= ICMP Error packet generation ========= + +// icmpReason is a marker interface for IPv6 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonParameterProblem is an error during processing of extension headers +// or the fixed header defined in RFC 4443 section 3.4. +type icmpReasonParameterProblem struct { + code header.ICMPv6Code + + // respondToMulticast indicates that we are sending a packet that falls under + // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondToMulticast bool + + // pointer is defined in the RFC 4443 setion 3.4 which reads: + // + // Pointer Identifies the octet offset within the invoking packet + // where the error was detected. + // + // The pointer will point beyond the end of the ICMPv6 + // packet if the field in error is beyond what can fit + // in the maximum size of an ICMPv6 error message. + pointer uint32 +} + +func (*icmpReasonParameterProblem) isICMPReason() {} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv6 and sends it. +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + stats := r.Stats().ICMP + sent := stats.V6PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + // Only send ICMP error if the address is not a multicast v6 + // address and the source is not the unspecified address. + // + // There are exceptions to this rule. + // See: point e.3) RFC 4443 section-2.4 + // + // (e) An ICMPv6 error message MUST NOT be originated as a result of + // receiving the following: + // + // (e.1) An ICMPv6 error message. + // + // (e.2) An ICMPv6 redirect message [IPv6-DISC]. + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + // + var allowResponseToMulticast bool + if reason, ok := reason.(*icmpReasonParameterProblem); ok { + allowResponseToMulticast = reason.respondToMulticast + } + + if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + return nil + } + + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + + if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { + // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. + // Unfortunately at this time ICMP Packets do not have a transport + // header separated out. It is in the Data part so we need to + // separate it out now. We will just pretend it is a minimal length + // ICMP packet as we don't really care if any later bits of a + // larger ICMP packet are in the header view or in the Data view. + transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) + if !ok { + return nil + } + typ := header.ICMPv6(transport).Type() + if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { + return nil + } + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + available := int(mtu) - headerLen + if available < header.IPv6MinimumSize { + return nil + } + payloadLen := network.Size() + transport.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload.CapLength(payloadLen) + + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + + icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + icmpHdr.SetType(header.ICMPv6ParamProblem) + icmpHdr.SetCode(reason.code) + icmpHdr.SetTypeSpecific(reason.pointer) + counter = sent.ParamProblem + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + counter = sent.DstUnreachable + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) + err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) + if err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 52a01b44e..31370c1d4 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -31,9 +31,14 @@ import ( ) const ( + nicID = 1 + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + + defaultChannelSize = 1 + defaultMTU = 65536 ) var ( @@ -46,7 +51,10 @@ type stubLinkEndpoint struct { } func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 + // Indicate that resolution for link layer addresses is required to send + // packets over this link. This is needed so the NIC knows to allocate a + // neighbor table. + return stack.CapabilityResolutionRequired } func (*stubLinkEndpoint) MaxHeaderLength() uint16 { @@ -67,7 +75,8 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { + return stack.TransportPacketHandled } type stubLinkAddressCache struct { @@ -81,16 +90,212 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { } +type stubNUDHandler struct{} + +var _ stack.NUDHandler = (*stubNUDHandler)(nil) + +func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +} + +func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +} + +func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +} + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestICMPCounts(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: test.useNeighborCache, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + typ header.ICMPv6Type + size int + extraData []byte + }{ + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, + } + + handleIPv6Payload := func(icmp header.ICMPv6) { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(&r, pkt) + } + + for _, typ := range types { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) + } + + // Construct an empty ICMP packet so that + // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + + icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { + if got, want := s.Value(), uint64(1); got != want { + t.Errorf("got %s = %d, want = %d", name, got, want) + } + }) + if t.Failed() { + t.Logf("stats:\n%+v", s.Stats()) + } + }) + } +} + +func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: true, }) { - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } } @@ -102,7 +307,7 @@ func TestICMPCounts(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -111,14 +316,16 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) } defer r.Release() @@ -180,7 +387,11 @@ func TestICMPCounts(t *testing.T) { } handleIPv6Payload := func(icmp header.ICMPv6) { - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -188,10 +399,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: buffer.View(icmp).ToVectorisedView(), - }) + ep.HandlePacket(&r, pkt) } for _, typ := range types { @@ -248,35 +456,34 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), } - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { + if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { + if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { + if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } @@ -287,7 +494,7 @@ func newTestContext(t *testing.T) *testContext { c.s0.SetRouteTable( []tcpip.Route{{ Destination: subnet0, - NIC: 1, + NIC: nicID, }}, ) subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) @@ -297,7 +504,7 @@ func newTestContext(t *testing.T) *testContext { c.s1.SetRouteTable( []tcpip.Route{{ Destination: subnet1, - NIC: 1, + NIC: nicID, }}, ) @@ -321,12 +528,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ - Data: vv, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } if pi.Proto != ProtocolNumber { @@ -338,7 +543,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } - ipv6 := header.IPv6(pi.Pkt.Header.View()) + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -358,9 +565,9 @@ func TestLinkResolution(t *testing.T) { c := newTestContext(t) defer c.cleanup() - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) } defer r.Release() @@ -375,14 +582,14 @@ func TestLinkResolution(t *testing.T) { var wq waiter.Queue ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) + _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}) if resCh != nil { if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) + t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err) } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, @@ -398,7 +605,7 @@ func TestLinkResolution(t *testing.T) { continue } if err != nil { - t.Fatalf("ep.Write(_) = _, _, %s", err) + t.Fatalf("ep.Write(_) = (_, _, %s)", err) } break } @@ -423,6 +630,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { size int extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool }{ { name: "DstUnreachable", @@ -479,6 +687,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) { statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }, + // Hosts MUST silently discard any received Router Solicitation messages. + routerOnly: true, }, { name: "RouterAdvert", @@ -515,83 +725,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }, } - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" + } + t.Run(name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr0) + + // Indicate that resolution for link layer addresses is required to + // send packets over this link. This is needed so the NIC knows to + // allocate a neighbor table. + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: test.useNeighborCache, + }) + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + if checksum { + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + } + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), + }) + e.InjectInbound(ProtocolNumber, pkt) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Router only count should not have increased. + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if !isRouter && typ.routerOnly && test.useNeighborCache { + // Router only count should have increased. + if got := routerOnly.Value(); got != 1 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) + } + } + }) } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) } }) } @@ -692,13 +952,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } { @@ -709,7 +969,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -717,12 +977,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -733,9 +993,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -747,7 +1008,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Without setting checksum, the incoming packet should @@ -758,13 +1019,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { } // Rx count of type typ.typ should not have increased. if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // When checksum is set, it should be received. handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Invalid count should not have increased again. if got := invalid.Value(); got != 1 { @@ -869,14 +1130,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -886,21 +1147,21 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -911,9 +1172,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -925,7 +1187,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Without setting checksum, the incoming packet should @@ -936,13 +1198,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } // Rx count of type typ.typ should not have increased. if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // When checksum is set, it should be received. handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Invalid count should not have increased again. if got := invalid.Value(); got != 1 { @@ -951,3 +1213,50 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { }) } } + +func TestLinkAddressRequest(t *testing.T) { + snaddr := header.SolicitedNodeAddr(lladdr0) + mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) + + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectLinkAddr: linkAddr1, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: mcaddr, + }, + } + + for _, test := range tests { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 95fbcf2d1..c8a3e0b34 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,23 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv6 contains the implementation of the ipv6 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv6.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv6 contains the implementation of the ipv6 network protocol. package ipv6 import ( "fmt" + "sort" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" - "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -45,46 +42,313 @@ const ( DefaultTTL = 64 ) +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) +var _ stack.NDPEndpoint = (*endpoint)(nil) +var _ NDPEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache + nud stack.NUDHandler dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation protocol *protocol + stack *stack.Stack + + // enabled is set to 1 when the endpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + ndp ndpState + } } -// DefaultTTL is the default hop limit for this endpoint. -func (e *endpoint) DefaultTTL() uint8 { - return e.protocol.DefaultTTL() +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it is passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. -func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) +// InvalidateDefaultRouter implements stack.NDPEndpoint. +func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.invalidateDefaultRouter(rtr) +} + +// SetNDPConfigurations implements NDPEndpoint. +func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) { + c.validate() + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.configs = c +} + +// hasTentativeAddr returns true if addr is tentative on e. +func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { + e.mu.RLock() + addressEndpoint := e.getAddressRLocked(addr) + e.mu.RUnlock() + return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative +} + +// dupTentativeAddrDetected attempts to inform e that a tentative addr is a +// duplicate on a link. +// +// dupTentativeAddrDetected removes the tentative address if it exists. If the +// address was generated via SLAAC, an attempt is made to generate a new +// address. +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return tcpip.ErrBadAddress + } + + if addressEndpoint.GetKind() != stack.PermanentTentative { + return tcpip.ErrInvalidEndpointState + } + + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + return err + } + + prefix := addressEndpoint.AddressWithPrefix().Subnet() + + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil +} + +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.Enabled() { + return + } + + if forwarding { + // When transitioning into an IPv6 router, host-only state (NDP discovered + // routers, discovered on-link prefixes, and auto-generated addresses) is + // cleaned up/invalidated and NDP router solicitations are stopped. + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(true /* hostOnly */) + } else { + // When transitioning into an IPv6 host, NDP router solicitations are + // started. + e.mu.ndp.startSolicitingRouters() + } +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + // + // Also auto-generate an IPv6 link-local address based on the endpoint's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the endpoint + // was last enabled, other devices may have acquired the same addresses. + var err *tcpip.Error + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + addr := addressEndpoint.AddressWithPrefix().Address + if !header.IsV6UnicastAddress(addr) { + return true + } + + switch addressEndpoint.GetKind() { + case stack.Permanent: + addressEndpoint.SetKind(stack.PermanentTentative) + fallthrough + case stack.PermanentTentative: + err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint) + return err == nil + default: + return true + } + }) + if err != nil { + return err + } + + // Do not auto-generate an IPv6 link-local address for loopback devices. + if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) + } + + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyway. + // + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !e.protocol.Forwarding() { + e.mu.ndp.startSolicitingRouters() + } + + return nil +} + +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() } -// ID returns the ipv6 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(false /* hostOnly */) + e.stopDADForPermanentAddressesLocked() + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) + } } -// PrefixLen returns the ipv6 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen +// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) stopDADForPermanentAddressesLocked() { + // Stop DAD for all the tentative unicast addresses. + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.GetKind() != stack.PermanentTentative { + return true + } + + addr := addressEndpoint.AddressWithPrefix().Address + if header.IsV6UnicastAddress(addr) { + e.mu.ndp.stopDuplicateAddressDetection(addr) + } + + return true + }) } -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() +// DefaultTTL is the default hop limit for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return e.protocol.DefaultTTL() +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and @@ -101,9 +365,9 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -112,25 +376,46 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - return ip + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pkt, params) + + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() + return nil + } + + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) loopedR.Release() } @@ -138,8 +423,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) + return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -152,13 +441,57 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pb := pkts.Front(); pb != nil; pb = pb.Next() { - ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) - pb.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pb, params) + } + + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } + return n, err + } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + + // Slow path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet @@ -171,23 +504,47 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv6(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + if !e.isEnabled() { + return + } + + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + // As per RFC 4291 section 2.7: + // Multicast addresses must not be used as source addresses in IPv6 + // packets or appear in any Routing header. + if header.IsV6MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. - vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() - vv.AppendView(pkt.TransportHeader) + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) vv.Append(pkt.Data) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false - for firstHeader := true; ; firstHeader = false { + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() + return + } + + for { + // Keep track of the start of the previous header so we can report the + // special case of a Hop by Hop at a location other than at the start. + previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -201,11 +558,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 - // (unrecognized next header) error in response to an extension header's - // Next Header field with the Hop By Hop extension header identifier. - if !firstHeader { + if previousHeaderStart != 0 { + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + }, pkt) return } @@ -227,13 +584,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) @@ -244,16 +613,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.4, if a node encounters a routing header with // an unrecognized routing type value, with a non-zero Segments Left // value, the node must discard the packet and send an ICMP Parameter - // Problem, Code 0. If the Segments Left is 0, the node must ignore the - // Routing extension header and process the next header in the packet. + // Problem, Code 0 to the packet's Source Address, pointing to the + // unrecognized Routing Type. + // + // If the Segments Left is 0, the node must ignore the Routing extension + // header and process the next header in the packet. // // Note, the stack does not yet handle any type of routing extension // header, so we just make sure Segments Left is zero before processing // the next extension header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for - // unrecognized routing types with a non-zero Segments Left value. if extHdr.SegmentsLeft() != 0 { + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: it.ParseOffset(), + }, pkt) return } @@ -286,7 +659,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() return } if done { @@ -329,32 +702,44 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - last := start + uint16(fragmentPayloadLen) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < start { + // Drop the fragment if the size of the reassembled payload would exceed + // the maximum payload size. + if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } - var ready bool // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf) + data, proto, ready, err := e.protocol.fragmentation.Process( + // IPv6 ignores the Protocol field since the ID only needs to be unique + // across source-destination pairs, as per RFC 8200 section 4.5. + fragmentation.FragmentID{ + Source: h.SourceAddress(), + Destination: h.DestinationAddress(), + ID: extHdr.ID(), + }, + start, + start+uint16(fragmentPayloadLen)-1, + extHdr.More(), + uint8(rawPayload.Identifier), + rawPayload.Buf, + ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } + pkt.Data = data if ready { // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC - // 8200 section 4.5. - it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data) + // 8200 section 4.5. We also use the NextHeader value from the first + // fragment. + it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data) } case header.IPv6DestinationOptionsExtHdr: @@ -376,13 +761,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) @@ -398,24 +795,58 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(len(pkt.TransportHeader)) + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf + r.Stats().IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. - e.dispatcher.DeliverTransportPacket(r, p, pkt) + switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + case stack.TransportPacketHandled: + case stack.TransportPacketDestinationPortUnreachable: + // As per RFC 4443 section 3.1: + // A destination node SHOULD originate a Destination Unreachable + // message with Code 4 in response to a packet for which the + // transport protocol (e.g., UDP) has no listener, if that transport + // protocol has no alternative means to inform the sender. + _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC 8200 section 4. (page 7): + // Extension headers are numbered from IANA IP Protocol Numbers + // [IANA-PN], the same values used for IPv4 and IPv6. When + // processing a sequence of Next Header values in a packet, the + // first one that is not an extension header [IANA-EH] indicates + // that the next item in the packet is the corresponding upper-layer + // header. + // With more related information on page 8: + // If, as a result of processing a header, the destination node is + // required to proceed to the next header but the Next Header value + // in the current header is unrecognized by the node, it should + // discard the packet and send an ICMP Parameter Problem message to + // the source of the packet, with an ICMP Code value of 1 + // ("unrecognized Next Header type encountered") and the ICMP + // Pointer field containing the offset of the unrecognized value + // within the original packet. + // + // Which when taken together indicate that an unknown protocol should + // be treated as an unrecognized next header value. + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) + default: + panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) + } } default: - // If we receive a packet for an extension header we do not yet handle, - // drop the packet for now. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) r.Stats().UnknownProtocolRcvdPackets.Increment() return } @@ -423,18 +854,340 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } // Close cleans up resources associated with the endpoint. -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + e.disableLocked() + e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */) + e.stopDADForPermanentAddressesLocked() + e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() + + e.protocol.forgetEndpoint(e) +} // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + // TODO(b/169350103): add checks here after making sure we no longer receive + // an empty address. + e.mu.Lock() + defer e.mu.Unlock() + return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) +} + +// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but +// with locking requirements. +// +// addAndAcquirePermanentAddressLocked also joins the passed address's +// solicited-node multicast group and start duplicate address detection. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err != nil { + return nil, err + } + + if !header.IsV6UnicastAddress(addr.Address) { + return addressEndpoint, nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { + return nil, err + } + + addressEndpoint.SetKind(stack.PermanentTentative) + + if e.Enabled() { + if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil { + return nil, err + } + } + + return addressEndpoint, nil +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + return e.removePermanentEndpointLocked(addressEndpoint, true) +} + +// removePermanentEndpointLocked is like removePermanentAddressLocked except +// it works with a stack.AddressEndpoint. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { + addr := addressEndpoint.AddressWithPrefix() + unicast := header.IsV6UnicastAddress(addr.Address) + if unicast { + e.mu.ndp.stopDuplicateAddressDetection(addr.Address) + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + switch addressEndpoint.ConfigType() { + case stack.AddressConfigSlaac: + e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case stack.AddressConfigSlaacTemp: + e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + } + } + + if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil { + return err + } + + if !unicast { + return nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil +} + +// hasPermanentAddressLocked returns true if the endpoint has a permanent +// address equal to the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return false + } + return addressEndpoint.GetKind().IsPermanent() +} + +// getAddressRLocked returns the endpoint for the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { + return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) +} + +// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with +// locking requirements. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB) +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + // addrCandidate is a candidate for Source Address Selection, as per + // RFC 6724 section 5. + type addrCandidate struct { + addressEndpoint stack.AddressEndpoint + scope header.IPv6AddressScope + } + + if len(remoteAddr) == 0 { + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + var cs []addrCandidate + e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !addressEndpoint.IsAssigned(allowExpired) { + return + } + + addr := addressEndpoint.AddressWithPrefix().Address + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) + } + + cs = append(cs, addrCandidate{ + addressEndpoint: addressEndpoint, + scope: scope, + }) + }) + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return true + } + if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { + return saTemp + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if c.addressEndpoint.IncRef() { + return c.addressEndpoint + } + } + + return nil +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV6MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) + type protocol struct { + stack *stack.Stack + + mu struct { + sync.RWMutex + + eps map[*endpoint]struct{} + } + // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. defaultTTL uint32 + + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + + fragmentation *fragmentation.Fragmentation + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. + ndpConfigs NDPConfigurations + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + + // autoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -459,24 +1212,43 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - return &endpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, + nud: nud, dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, - }, nil + } + e.mu.addressableEndpointState.Init(e) + e.mu.ndp = ndpState{ + ep: e, + configs: p.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), + } + e.mu.ndp.initializeTempAddrState() + + p.mu.Lock() + defer p.mu.Unlock() + p.mu.eps[e] = struct{}{} + return e +} + +func (p *protocol) forgetEndpoint(e *endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, e) } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case tcpip.DefaultTTLOption: - p.SetDefaultTTL(uint8(v)) + case *tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(*v)) return nil default: return tcpip.ErrUnknownProtocolOption @@ -484,7 +1256,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) @@ -510,77 +1282,43 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return 0, false, false } - ipHdr := header.IPv6(hdr) - // dataClone consists of: - // - Any IPv6 header bytes after the first 40 (i.e. extensions). - // - The transport header, if present. - // - Any other payload data. - views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + return proto, !fragMore && fragOffset == 0, true +} - // Iterate over the IPv6 extensions to find their length. - // - // Parsing occurs again in HandlePacket because we don't track the - // extensions in PacketBuffer. Unfortunately, that means HandlePacket - // has to do the parsing work again. - var nextHdr tcpip.TransportProtocolNumber - foundNext := true - extensionsSize := 0 -traverseExtensions: - for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { - if err != nil { - break - } - // If we exhaust the extension list, the entire packet is the IPv6 header - // and (possibly) extensions. - if done { - extensionsSize = dataClone.Size() - foundNext = false - break - } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} - switch extHdr := extHdr.(type) { - case header.IPv6FragmentExtHdr: - // If this is an atomic fragment, we don't have to treat it specially. - if !extHdr.More() && extHdr.FragmentOffset() == 0 { - continue - } - // This is a non-atomic fragment and has to be re-assembled before we can - // examine the payload for a transport header. - foundNext = false +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.SwapUint32(&p.forwarding, 1) == 0 + } + return atomic.SwapUint32(&p.forwarding, 0) == 1 +} - case header.IPv6RawPayloadHeader: - // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() - nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) - break traverseExtensions +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + p.mu.Lock() + defer p.mu.Unlock() - default: - // Any other extension is a no-op, keep looping until we find the payload. - } + if !p.setForwarding(v) { + return } - // Put the IPv6 header with extensions in pkt.NetworkHeader. - hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) - if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + for ep := range p.mu.eps { + ep.transitionForwarding(v) } - ipHdr = header.IPv6(hdr) - - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) - - return nextHdr, foundNext, true } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -593,7 +1331,69 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -// NewProtocol returns an IPv6 network protocol. -func NewProtocol() stack.NetworkProtocol { - return &protocol{defaultTTL: DefaultTTL} +// Options holds options to configure a new protocol. +type Options struct { + // NDPConfigs is the default NDP configurations used by interfaces. + NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // + // Note, setting this to true does not mean that a link-local address is + // assigned right away, or at all. If Duplicate Address Detection is enabled, + // an address is only assigned if it successfully resolves. If it fails, no + // further attempts are made to auto-generate an IPv6 link-local adddress. + // + // The generated link-local address follows RFC 4291 Appendix A guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte +} + +// NewProtocolWithOptions returns an IPv6 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { + opts.NDPConfigs.validate() + + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + + ndpDisp: opts.NDPDisp, + ndpConfigs: opts.NDPConfigs, + opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + } + p.mu.eps = make(map[*endpoint]struct{}) + p.SetDefaultTTL(DefaultTTL) + return p + } +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 213ff64f2..d7f82973b 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,13 +15,16 @@ package ipv6 import ( + "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -65,9 +68,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -123,9 +126,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stat := s.Stats().UDP.PacketsReceived @@ -139,18 +142,18 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -172,11 +175,11 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) @@ -184,8 +187,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -271,7 +274,7 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -299,11 +302,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { name string extHdr func(nextHdr uint8) ([]byte, uint8) shouldAccept bool + // Should we expect an ICMP response and if so, with what contents? + expectICMP bool + ICMPType header.ICMPv6Type + ICMPCode header.ICMPv6Code + pointer uint32 + multicast bool }{ { name: "None", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, shouldAccept: true, + expectICMP: false, }, { name: "hopbyhop with unknown option skippable action", @@ -334,9 +344,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action", + name: "hopbyhop with unknown option discard and send icmp action (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -346,12 +377,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -362,39 +399,97 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + multicast: true, shouldAccept: false, + expectICMP: false, }, { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: true, }, { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 1, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6ErroneousHeader, + pointer: header.IPv6FixedHeaderSize + 2, }, { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 0, 0, 0, 0, + }, fragmentExtHdrID + }, shouldAccept: true, }, { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: true, + expectICMP: false, }, { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: false, + expectICMP: false, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{}, + noNextHdrID + }, shouldAccept: false, + expectICMP: false, }, { name: "destination with unknown option skippable action", @@ -410,6 +505,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: true, + expectICMP: false, }, { name: "destination with unknown option discard action", @@ -425,9 +521,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "destination with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. + }, destinationExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action", + name: "destination with unknown option discard and send icmp action (muilticast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -437,12 +554,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. }, destinationExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action unless multicast dest", + name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -453,22 +576,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "routing - atomic fragment", + name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + nextHdr, 1, - // Fragment extension header. - nextHdr, 0, 0, 0, 1, 2, 3, 4, - }, routingExtHdrID + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. + }, destinationExtHdrID }, - shouldAccept: true, + shouldAccept: false, + expectICMP: false, + multicast: true, }, { name: "atomic fragment - routing", @@ -502,12 +636,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { return []byte{ // Routing extension header. hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. // Hop By Hop extension header with skippable unknown option. nextHdr, 0, 62, 4, 1, 2, 3, 4, }, routingExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { + name: "routing - hop by hop (with send icmp unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. + + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, routingExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, }, { name: "No next header", @@ -551,6 +715,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", @@ -571,16 +736,17 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -588,6 +754,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + // Add a default route so that a return packet knows where to go. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -629,17 +803,21 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, HopLimit: 255, SrcAddr: addr1, - DstAddr: addr2, + DstAddr: dstAddr, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().UDP.PacketsReceived @@ -648,6 +826,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 0", got) } + if !test.expectICMP { + if p, ok := e.Read(); ok { + t.Fatalf("unexpected packet received: %#v", p) + } + return + } + + // ICMP required. + p, ok := e.Read() + if !ok { + t.Fatalf("expected packet wasn't written out") + } + + // Pack the output packet into a single buffer.View as the checkers + // assume that. + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { + t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) + } + + ipHdr := header.IPv6(pkt) + checker.IPv6(t, ipHdr, checker.ICMPv6( + checker.ICMPv6Type(test.ICMPType), + checker.ICMPv6Code(test.ICMPCode))) + + // We know we are looking at no extension headers in the error ICMP + // packets. + icmpPkt := header.ICMPv6(ipHdr.Payload()) + // We know we sent small packets that won't be truncated when reflected + // back to us. + originalPacket := icmpPkt.Payload() + if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { + t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) + } + if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { + t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) + } return } @@ -673,20 +889,28 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // fragmentData holds the IPv6 payload for a fragmented IPv6 packet. type fragmentData struct { + srcAddr tcpip.Address + dstAddr tcpip.Address nextHdr uint8 data buffer.VectorisedView } func TestReceiveIPv6Fragments(t *testing.T) { - const nicID = 1 - const udpPayload1Length = 256 - const udpPayload2Length = 128 - const fragmentExtHdrLen = 8 - // Note, not all routing extension headers will be 8 bytes but this test - // uses 8 byte routing extension headers for most sub tests. - const routingExtHdrLen = 8 - - udpGen := func(payload []byte, multiplier uint8) buffer.View { + const ( + nicID = 1 + udpPayload1Length = 256 + udpPayload2Length = 128 + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 8200 section 4.5). + udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize + fragmentExtHdrLen = 8 + // Note, not all routing extension headers will be 8 bytes but this test + // uses 8 byte routing extension headers for most sub tests. + routingExtHdrLen = 8 + ) + + udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View { payloadLen := len(payload) for i := 0; i < payloadLen; i++ { payload[i] = uint8(i) * multiplier @@ -702,19 +926,31 @@ func TestReceiveIPv6Fragments(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) sum = header.Checksum(payload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) return hdr.View() } - var udpPayload1Buf [udpPayload1Length]byte - udpPayload1 := udpPayload1Buf[:] - ipv6Payload1 := udpGen(udpPayload1, 1) + var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte + udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:] + ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2) + + var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte + udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:] + ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2) + + var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte + udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:] + ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2) - var udpPayload2Buf [udpPayload2Length]byte - udpPayload2 := udpPayload2Buf[:] - ipv6Payload2 := udpGen(udpPayload2, 2) + var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte + udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] + ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) + + var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte + udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] + ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) tests := []struct { name string @@ -726,34 +962,60 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "No fragmentation", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: uint8(header.UDPProtocolNumber), - data: ipv6Payload1.ToVectorisedView(), + data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Atomic fragment", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload1Addr1ToAddr2, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Atomic fragment with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1), + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), []buffer.View{ // Fragment extension header. buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), - ipv6Payload1, + ipv6Payload3Addr1ToAddr2, }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, }, { name: "Two fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -763,31 +1025,189 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments out of order", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with different Next Header values", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + // NextHeader value is different than the one in the first fragment, so + // this NextHeader should be ignored. + buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[:63], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[63:], + }, + ), + }, + }, + expectedPayloads: nil, }, { name: "Two fragments with different IDs", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -797,21 +1217,23 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 2 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, @@ -819,9 +1241,49 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: nil, }, { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:65520], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[65520:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -836,14 +1298,16 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Routing extension header. // @@ -855,17 +1319,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Two fragments with per-fragment routing header with non-zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -880,14 +1346,16 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Routing extension header. // @@ -899,7 +1367,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 9, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, @@ -910,6 +1378,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with routing header with zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -924,31 +1394,35 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Segments left = 0. buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Two fragments with routing header with non-zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -963,21 +1437,23 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Segments left = 1. buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, @@ -988,6 +1464,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with routing header with zero segments left across fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is fragmentExtHdrLen+8 because the @@ -1008,12 +1486,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1), + fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. // @@ -1023,7 +1503,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header (part 2) buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - ipv6Payload1, + ipv6Payload1Addr1ToAddr2, }, ), }, @@ -1034,6 +1514,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with routing header with non-zero segments left across fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is fragmentExtHdrLen+8 because the @@ -1054,12 +1536,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1), + fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. // @@ -1069,7 +1553,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header (part 2) buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - ipv6Payload1, + ipv6Payload1Addr1ToAddr2, }, ), }, @@ -1082,6 +1566,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with atomic", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -1091,47 +1577,53 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, // This fragment has the same ID as the other fragments but is an atomic // fragment. It should not interfere with the other fragments. { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2), + fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2), []buffer.View{ // Fragment extension header. // // Fragment offset = 0, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), - ipv6Payload2, + ipv6Payload2Addr1ToAddr2, }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload2, udpPayload1}, + expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2}, }, { name: "Two interleaved fragmented packets", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -1141,11 +1633,13 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+32, @@ -1155,48 +1649,122 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 2 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), - ipv6Payload2[:32], + ipv6Payload2Addr1ToAddr2[:32], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2)-32, + fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32, []buffer.View{ // Fragment extension header. // // Fragment offset = 4, More = false, ID = 2 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), - ipv6Payload2[32:], + ipv6Payload2Addr1ToAddr2[32:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets from different sources but with same ID", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr3, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr3ToAddr2[:32], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr3, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 4, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), + + ipv6Payload1Addr3ToAddr2[32:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -1231,16 +1799,16 @@ func TestReceiveIPv6Fragments(t *testing.T) { PayloadLength: uint16(f.data.Size()), NextHeader: f.nextHdr, HopLimit: 255, - SrcAddr: addr1, - DstAddr: addr2, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { @@ -1263,3 +1831,308 @@ func TestReceiveIPv6Fragments(t *testing.T) { }) } } + +func TestInvalidIPv6Fragments(t *testing.T) { + const ( + nicID = 1 + fragmentExtHdrLen = 8 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + tests := []struct { + name string + fragments []fragmentData + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + name: "fragments reassembled into a payload exceeding the max IPv6 payload size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, + []buffer.View{ + // Fragment extension header. + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, + ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, + 0, 0, 0, 1}), + // Payload length = 16 + payloadGen(16), + }, + ), + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + }) + e := channel.New(0, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + + // Serialize IPv6 fixed header. + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(f.data.Size()), + NextHeader: f.nextHdr, + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, + }) + + vv := hdr.View().ToVectorisedView() + vv.Append(f.data) + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { + t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { + t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) + } + }) + } +} + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } + const ( + src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ) + if err := s.AddAddress(1, ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + } + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + if err != nil { + t.Fatalf("NewSubnet(_, _) failed: %v", err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) + } + return rt +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if !hasEP { + t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) + } + } + + ep.Close() + + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) + } + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index e28c23d66..48a4c65e3 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +package ipv6 import ( "fmt" @@ -23,9 +23,27 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor // Solicitation messages to send when doing Duplicate Address Detection // for a tentative address. @@ -33,14 +51,8 @@ const ( // Default = 1 (from RFC 4862 section 5.1) defaultDupAddrDetectTransmits = 1 - // defaultRetransmitTimer is the default amount of time to wait between - // sending NDP Neighbor solicitation messages. - // - // Default = 1s (from RFC 4861 section 10). - defaultRetransmitTimer = time.Second - // defaultMaxRtrSolicitations is the default number of Router - // Solicitation messages to send when a NIC becomes enabled. + // Solicitation messages to send when an IPv6 endpoint becomes enabled. // // Default = 3 (from RFC 4861 section 10). defaultMaxRtrSolicitations = 3 @@ -79,16 +91,6 @@ const ( // Default = true. defaultAutoGenGlobalAddresses = true - // minimumRetransmitTimer is the minimum amount of time to wait between - // sending NDP Neighbor solicitation messages. Note, RFC 4861 does - // not impose a minimum Retransmit Timer, but we do here to make sure - // the messages are not sent all at once. We also come to this value - // because in the RetransmitTimer field of a Router Advertisement, a - // value of 0 means unspecified, so the smallest valid value is 1. - // Note, the unit of the RetransmitTimer field in the Router - // Advertisement is milliseconds. - minimumRetransmitTimer = time.Millisecond - // minimumRtrSolicitationInterval is the minimum amount of time to wait // between sending Router Solicitation messages. This limit is imposed // to make sure that Router Solicitation messages are not sent all at @@ -147,7 +149,7 @@ const ( minRegenAdvanceDuration = time.Duration(0) // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt - // SLAAC address regenerations in response to a NIC-local conflict. + // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. maxSLAACAddrLocalRegenAttempts = 10 ) @@ -179,7 +181,7 @@ var ( // This is exported as a variable (instead of a constant) so tests // can update it to a smaller value. // - // This value guarantees that a temporary address will be preferred for at + // This value guarantees that a temporary address is preferred for at // least 1hr if the SLAAC prefix is valid for at least that time. MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour @@ -189,11 +191,17 @@ var ( // This is exported as a variable (instead of a constant) so tests // can update it to a smaller value. // - // This value guarantees that a temporary address will be valid for at least + // This value guarantees that a temporary address is valid for at least // 2hrs if the SLAAC prefix is valid for at least that time. MinMaxTempAddrValidLifetime = 2 * time.Hour ) +// NDPEndpoint is an endpoint that supports NDP. +type NDPEndpoint interface { + // SetNDPConfigurations sets the NDP configurations. + SetNDPConfigurations(NDPConfigurations) +} + // DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an // NDP Router Advertisement informed the Stack about. type DHCPv6ConfigurationFromNDPRA int @@ -208,7 +216,7 @@ const ( // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. // // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 - // will return all available configuration information. + // returns all available configuration information when serving addresses. DHCPv6ManagedAddress // DHCPv6OtherConfigurations indicates that other configuration information is @@ -223,19 +231,18 @@ const ( // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { - // OnDuplicateAddressDetectionStatus will be called when the DAD process - // for an address (addr) on a NIC (with ID nicID) completes. resolved - // will be set to true if DAD completed successfully (no duplicate addr - // detected); false otherwise (addr was detected to be a duplicate on - // the link the NIC is a part of, or it was stopped for some other - // reason, such as the address being removed). If an error occured - // during DAD, err will be set and resolved must be ignored. + // OnDuplicateAddressDetectionStatus is called when the DAD process for an + // address (addr) on a NIC (with ID nicID) completes. resolved is set to true + // if DAD completed successfully (no duplicate addr detected); false otherwise + // (addr was detected to be a duplicate on the link the NIC is a part of, or + // it was stopped for some other reason, such as the address being removed). + // If an error occured during DAD, err is set and resolved must be ignored. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) - // OnDefaultRouterDiscovered will be called when a new default router is + // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered // router should be remembered. // @@ -243,56 +250,55 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool - // OnDefaultRouterInvalidated will be called when a discovered default - // router that was remembered is invalidated. + // OnDefaultRouterInvalidated is called when a discovered default router that + // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) - // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true if the newly discovered - // on-link prefix should be remembered. + // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. + // Implementations must return true if the newly discovered on-link prefix + // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool - // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix that was remembered is invalidated. + // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that + // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) - // OnAutoGenAddress will be called when a new prefix with its - // autonomous address-configuration flag set has been received and SLAAC - // has been performed. Implementations may prevent the stack from - // assigning the address to the NIC by returning false. + // OnAutoGenAddress is called when a new prefix with its autonomous address- + // configuration flag set is received and SLAAC was performed. Implementations + // may prevent the stack from assigning the address to the NIC by returning + // false. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool - // OnAutoGenAddressDeprecated will be called when an auto-generated - // address (as part of SLAAC) has been deprecated, but is still - // considered valid. Note, if an address is invalidated at the same - // time it is deprecated, the deprecation event MAY be omitted. + // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) + // is deprecated, but is still considered valid. Note, if an address is + // invalidated at the same ime it is deprecated, the deprecation event may not + // be received. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) - // OnAutoGenAddressInvalidated will be called when an auto-generated - // address (as part of SLAAC) has been invalidated. + // OnAutoGenAddressInvalidated is called when an auto-generated address + // (SLAAC) is invalidated. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) - // OnRecursiveDNSServerOption will be called when an NDP option with - // recursive DNS servers has been received. Note, addrs may contain - // link-local addresses. + // OnRecursiveDNSServerOption is called when the stack learns of DNS servers + // through NDP. Note, the addresses may contain link-local addresses. // // It is up to the caller to use the DNS Servers only for their valid // lifetime. OnRecursiveDNSServerOption may be called for new or @@ -304,8 +310,8 @@ type NDPDispatcher interface { // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) - // OnDNSSearchListOption will be called when an NDP option with a DNS - // search list has been received. + // OnDNSSearchListOption is called when the stack learns of DNS search lists + // through NDP. // // It is up to the caller to use the domain names in the search list // for only their valid lifetime. OnDNSSearchListOption may be called @@ -314,8 +320,8 @@ type NDPDispatcher interface { // be increased, decreased or completely invalidated when lifetime = 0. OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) - // OnDHCPv6Configuration will be called with an updated configuration that is - // available via DHCPv6 for a specified NIC. + // OnDHCPv6Configuration is called with an updated configuration that is + // available via DHCPv6 for the passed NIC. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. @@ -336,7 +342,7 @@ type NDPConfigurations struct { // Must be greater than or equal to 1ms. RetransmitTimer time.Duration - // The number of Router Solicitation messages to send when the NIC + // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. MaxRtrSolicitations uint8 @@ -351,24 +357,22 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements will be - // processed. + // HandleRAs determines whether or not Router Advertisements are processed. HandleRAs bool - // DiscoverDefaultRouters determines whether or not default routers will - // be discovered from Router Advertisements. This configuration is - // ignored if HandleRAs is false. + // DiscoverDefaultRouters determines whether or not default routers are + // discovered from Router Advertisements, as per RFC 4861 section 6. This + // configuration is ignored if HandleRAs is false. DiscoverDefaultRouters bool - // DiscoverOnLinkPrefixes determines whether or not on-link prefixes - // will be discovered from Router Advertisements' Prefix Information - // option. This configuration is ignored if HandleRAs is false. + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are + // discovered from Router Advertisements' Prefix Information option, as per + // RFC 4861 section 6. This configuration is ignored if HandleRAs is false. DiscoverOnLinkPrefixes bool - // AutoGenGlobalAddresses determines whether or not global IPv6 - // addresses will be generated for a NIC in response to receiving a new - // Prefix Information option with its Autonomous Address - // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs + // SLAAC to auto-generate global SLAAC addresses in response to Prefix + // Information options, as per RFC 4862. // // Note, if an address was already generated for some unique prefix, as // part of SLAAC, this option does not affect whether or not the @@ -382,12 +386,12 @@ type NDPConfigurations struct { // // If the method used to generate the address does not support creating // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's - // MAC address), then no attempt will be made to resolve the conflict. + // MAC address), then no attempt is made to resolve the conflict. AutoGenAddressConflictRetries uint8 // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC - // addresses will be generated for a NIC as part of SLAAC privacy extensions, - // RFC 4941. + // addresses are generated for an IPv6 endpoint as part of SLAAC privacy + // extensions, as per RFC 4941. // // Ignored if AutoGenGlobalAddresses is false. AutoGenTempGlobalAddresses bool @@ -426,7 +430,7 @@ func DefaultNDPConfigurations() NDPConfigurations { } // validate modifies an NDPConfigurations with valid values. If invalid values -// are present in c, the corresponding default values will be used instead. +// are present in c, the corresponding default values are used instead. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer @@ -455,8 +459,8 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { - // The NIC this ndpState is for. - nic *NIC + // The IPv6 endpoint this ndpState is for. + ep *endpoint // configs is the per-interface NDP configurations. configs NDPConfigurations @@ -469,13 +473,13 @@ type ndpState struct { rtrSolicit struct { // The timer used to send the next router solicitation message. - timer *time.Timer + timer tcpip.Timer // Used to let the Router Solicitation timer know that it has been stopped. // // Must only be read from or written to while protected by the lock of - // the NIC this ndpState is associated with. MUST be set when the timer is - // set. + // the IPv6 endpoint this ndpState is associated with. MUST be set when the + // timer is set. done *bool } @@ -503,57 +507,57 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // // Must only be read from or written to while protected by the lock of - // the NIC this dadState is associated with. + // the IPv6 endpoint this dadState is associated with. done *bool } // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - // Timer to invalidate the default router. + // Job to invalidate the default router. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - // Timer to invalidate the on-link prefix. + // Job to invalidate the on-link prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // tempSLAACAddrState holds state associated with a temporary SLAAC address. type tempSLAACAddrState struct { - // Timer to deprecate the temporary SLAAC address. + // Job to deprecate the temporary SLAAC address. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the temporary SLAAC address. + // Job to invalidate the temporary SLAAC address. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job - // Timer to regenerate the temporary SLAAC address. + // Job to regenerate the temporary SLAAC address. // // Must not be nil. - regenTimer *tcpip.CancellableTimer + regenJob *tcpip.Job createdAt time.Time // The address's endpoint. // // Must not be nil. - ref *referencedNetworkEndpoint + addressEndpoint stack.AddressEndpoint // Has a new temporary SLAAC address already been regenerated? regenerated bool @@ -561,15 +565,15 @@ type tempSLAACAddrState struct { // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - // Timer to deprecate the prefix. + // Job to deprecate the prefix. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the prefix. + // Job to invalidate the prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. validUntil time.Time @@ -583,10 +587,10 @@ type slaacPrefixState struct { // // May only be nil when the address is being (re-)generated. Otherwise, // must not be nil as all SLAAC prefixes must have a stable address. - ref *referencedNetworkEndpoint + addressEndpoint stack.AddressEndpoint - // The number of times an address has been generated locally where the NIC - // already had the generated address. + // The number of times an address has been generated locally where the IPv6 + // endpoint already had the generated address. localGenerationFailures uint8 } @@ -594,11 +598,12 @@ type slaacPrefixState struct { tempAddrs map[tcpip.Address]tempSLAACAddrState // The next two fields are used by both stable and temporary addresses - // generated for a SLAAC prefix. This is safe as only 1 address will be - // in the generation and DAD process at any time. That is, no two addresses - // will be generated at the same time for a given SLAAC prefix. + // generated for a SLAAC prefix. This is safe as only 1 address is in the + // generation and DAD process at any time. That is, no two addresses are + // generated at the same time for a given SLAAC prefix. - // The number of times an address has been generated and added to the NIC. + // The number of times an address has been generated and added to the IPv6 + // endpoint. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 @@ -613,16 +618,16 @@ type slaacPrefixState struct { // This function must only be called by IPv6 addresses that are currently // tentative. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { return tcpip.ErrAddressFamilyNotSupported } - if ref.getKind() != permanentTentative { + if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should be marked as tentative since we are starting DAD. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } // Should not attempt to perform DAD on an address that is currently in the @@ -633,45 +638,45 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // existed, we would get an error since we attempted to add a duplicate // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. - // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + // See endpoint.addAddressLocked. + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits if remaining == 0 { - ref.setKind(permanent) + addressEndpoint.SetKind(stack.Permanent) // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } return nil } var done bool - var timer *time.Timer + var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the NIC's lock. This is effectively the same - // as starting a goroutine but we use a timer that fires immediately so we can - // reset it for the next DAD iteration. - timer = time.AfterFunc(0, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { + ndp.ep.mu.Lock() + defer ndp.ep.mu.Unlock() if done { // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the NIC lock and stopped DAD - // before this function obtained the NIC lock. Simply return here and do - // nothing further. + // another goroutine already obtained the IPv6 endpoint lock and stopped + // DAD before this function obtained the NIC lock. Simply return here and + // do nothing further. return } - if ref.getKind() != permanentTentative { + if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } dadDone := remaining == 0 @@ -679,33 +684,34 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref var err *tcpip.Error if !dadDone { // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) // Do not hold the lock when sending packets which may be a long running // task or may block link address resolution. We know this is safe // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the NIC. Note, DAD would be - // stopped if the NIC was disabled or removed, or if the address was - // removed. - ndp.nic.mu.Unlock() - err = ndp.sendDADPacket(addr, ref) - ndp.nic.mu.Lock() + // has been stopped before doing any work with the IPv6 endpoint. Note, + // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if + // the address was removed. + ndp.ep.mu.Unlock() + err = ndp.sendDADPacket(addr, addressEndpoint) + ndp.ep.mu.Lock() + addressEndpoint.DecRef() } if done { // If we reach this point, it means that DAD was stopped after we released - // the NIC's read lock and before we obtained the write lock. + // the IPv6 endpoint's read lock and before we obtained the write lock. return } if dadDone { // DAD has resolved. - ref.setKind(permanent) + addressEndpoint.SetKind(stack.Permanent) } else if err == nil { // DAD is not done and we had no errors when sending the last NDP NS, // schedule the next DAD timer. remaining-- - timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + timer.Reset(ndp.configs.RetransmitTimer) return } @@ -714,16 +720,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // integrator know DAD has completed. delete(ndp.dad, addr) - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } // If DAD resolved for a stable SLAAC address, attempt generation of a // temporary SLAAC address. - if dadDone && ref.configType == slaac { + if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) } }) @@ -738,44 +744,50 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns // addr. // -// addr must be a tentative IPv6 address on ndp's NIC. +// addr must be a tentative IPv6 address on ndp's IPv6 endpoint. // -// The NIC ndp belongs to MUST NOT be locked. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The IPv6 endpoint that ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. + // Since this method is required to be called when the IPv6 endpoint is not + // locked, the NIC could have been disabled or removed by another goroutine. if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { return err } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, - NetworkHeaderParams{ + stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() return err @@ -790,11 +802,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd // such a state forever, unless some other external event resolves the DAD // process (receiving an NA from the true owner of addr, or an NS for addr // (implying another node is attempting to use addr)). It is up to the caller -// of this function to handle such a scenario. Normally, addr will be removed -// from n right after this function returns or the address successfully -// resolved. +// of this function to handle such a scenario. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { dad, ok := ndp.dad[addr] if !ok { @@ -813,30 +823,30 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } // handleRA handles a Router Advertisement message that arrived on the NIC // this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - // Is the NIC configured to handle RAs at all? + // Is the IPv6 endpoint configured to handle RAs at all? // // Currently, the stack does not determine router interface status on a - // per-interface basis; it is a stack-wide configuration, so we check - // stack's forwarding flag to determine if the NIC is a routing - // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + // per-interface basis; it is a protocol-wide configuration, so we check the + // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding + // packets. + if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { return } // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -851,11 +861,11 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { if ndp.dhcpv6Configuration != configuration { ndp.dhcpv6Configuration = configuration - ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration) } } - // Is the NIC configured to discover default routers? + // Is the IPv6 endpoint configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] rl := ra.RouterLifetime() @@ -871,9 +881,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update - // the invalidation timer. - rtr.invalidationTimer.StopLocked() - rtr.invalidationTimer.Reset(rl) + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -893,20 +903,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.nic.stack.ndpDisp == nil { + if ndp.ep.protocol.ndpDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.nic.stack.ndpDisp == nil { + if ndp.ep.protocol.ndpDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -940,7 +950,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // invalidateDefaultRouter invalidates a discovered default router. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr, ok := ndp.defaultRouters[ip] @@ -950,41 +960,41 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.StopLocked() + rtr.invalidationJob.Cancel() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } // rememberDefaultRouter remembers a newly discovered default router with IPv6 // link-local address ip with lifetime rl. // -// The router identified by ip MUST NOT already be known by the NIC. +// The router identified by ip MUST NOT already be known by the IPv6 endpoint. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return } state := defaultRouterState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateDefaultRouter(ip) }), } - state.invalidationTimer.Reset(rl) + state.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = state } @@ -994,28 +1004,28 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The prefix identified by prefix MUST NOT already be known. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return } state := onLinkPrefixState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } if l < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(l) + state.invalidationJob.Schedule(l) } ndp.onLinkPrefixes[prefix] = state @@ -1023,7 +1033,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) // invalidateOnLinkPrefix invalidates a discovered on-link prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { s, ok := ndp.onLinkPrefixes[prefix] @@ -1033,12 +1043,12 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - s.invalidationTimer.StopLocked() + s.invalidationJob.Cancel() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1048,7 +1058,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { // handleOnLinkPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the on-link flag is set. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { prefix := pi.Subnet() prefixState, ok := ndp.onLinkPrefixes[prefix] @@ -1082,14 +1092,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. // - // Update the invalidation timer. + // Update the invalidation job. - prefixState.invalidationTimer.StopLocked() + prefixState.invalidationJob.Cancel() if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, reset the timer to expire after + // Prefix is valid for a finite lifetime, schedule the job to execute after // the new valid lifetime. - prefixState.invalidationTimer.Reset(vl) + prefixState.invalidationJob.Schedule(vl) } ndp.onLinkPrefixes[prefix] = prefixState @@ -1101,7 +1111,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // handleAutonomousPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the autonomous flag is set. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { vl := pi.ValidLifetime() pl := pi.PreferredLifetime() @@ -1137,7 +1147,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform // // pl is the new preferred lifetime. vl is the new valid lifetime. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 @@ -1154,15 +1164,15 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.stableAddr.ref) + ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1184,24 +1194,24 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if !ndp.generateSLAACAddr(prefix, &state) { // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or timers for a prefix we + // further as there is no reason to maintain state or jobs for a prefix we // do not have an address for. return } - // Setup the initial timers to deprecate and invalidate prefix. + // Setup the initial jobs to deprecate and invalidate prefix. if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationTimer.Reset(pl) + state.deprecationJob.Schedule(pl) } if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) + state.invalidationJob.Schedule(vl) state.validUntil = now.Add(vl) } // If the address is assigned (DAD resolved), generate a temporary address. - if state.stableAddr.ref.getKind() == permanent { + if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) @@ -1210,32 +1220,27 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.slaacPrefixes[prefix] = state } -// addSLAACAddr adds a SLAAC address to the NIC. +// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return nil } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { // Informed by the integrator not to add the address. return nil } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr, - } - - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { - panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } - return ref + return addressEndpoint } // generateSLAACAddr generates a SLAAC address for prefix. @@ -1244,10 +1249,10 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo // // Panics if the prefix is not a SLAAC prefix or it already has an address. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.stableAddr.ref; r != nil { - panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix())) } // If we have already reached the maximum address generation attempts for the @@ -1267,11 +1272,11 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, - oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()), dadCounter, oIID.SecretKey, ) @@ -1284,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.nic.linkEP.LinkAddress() + linkAddr := ndp.ep.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } @@ -1303,15 +1308,15 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt PrefixLen: validPrefixLenForAutoGen, } - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { break } state.stableAddr.localGenerationFailures++ } - if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { - state.stableAddr.ref = ref + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true } @@ -1321,10 +1326,9 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // regenerateSLAACAddr regenerates an address for a SLAAC prefix. // -// If generating a new address for the prefix fails, the prefix will be -// invalidated. +// If generating a new address for the prefix fails, the prefix is invalidated. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { state, ok := ndp.slaacPrefixes[prefix] if !ok { @@ -1344,7 +1348,7 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { // generateTempSLAACAddr generates a new temporary SLAAC address. // -// If resetGenAttempts is true, the prefix's generation counter will be reset. +// If resetGenAttempts is true, the prefix's generation counter is reset. // // Returns true if a new address was generated. func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { @@ -1365,7 +1369,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress + stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address now := time.Now() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary @@ -1404,7 +1408,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - // Attempt to generate a new address that is not already assigned to the NIC. + // Attempt to generate a new address that is not already assigned to the IPv6 + // endpoint. var generatedAddr tcpip.AddressWithPrefix for i := 0; ; i++ { // If we were unable to generate an address after the maximum SLAAC address @@ -1414,7 +1419,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { break } } @@ -1422,13 +1427,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary // address with a zero preferred lifetime. The checks above ensure this // so we know the address is not deprecated. - ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) - if ref == nil { + addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */) + if addressEndpoint == nil { return false } state := tempSLAACAddrState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1439,9 +1444,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) } - ndp.deprecateSLAACAddress(tempAddrState.ref) + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1454,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1477,13 +1482,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla prefixState.tempAddrs[generatedAddr.Address] = tempAddrState ndp.slaacPrefixes[prefix] = prefixState }), - createdAt: now, - ref: ref, + createdAt: now, + addressEndpoint: addressEndpoint, } - state.deprecationTimer.Reset(pl) - state.invalidationTimer.Reset(vl) - state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) prefixState.generationAttempts++ prefixState.tempAddrs[generatedAddr.Address] = state @@ -1493,7 +1498,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { state, ok := ndp.slaacPrefixes[prefix] if !ok { @@ -1508,26 +1513,26 @@ func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttemp // // pl is the new preferred lifetime. vl is the new valid lifetime. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) + ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint) } else { - prefixState.stableAddr.ref.deprecated = false + prefixState.stableAddr.addressEndpoint.SetDeprecated(false) } - // If prefix was preferred for some finite lifetime before, stop the - // deprecation timer so it can be reset. - prefixState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() now := time.Now() - // Reset the deprecation timer if prefix has a finite preferred lifetime. + // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { - prefixState.deprecationTimer.Reset(pl) + prefixState.deprecationJob.Schedule(pl) } prefixState.preferredUntil = now.Add(pl) } else { @@ -1546,9 +1551,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not keep a timer - // in this case. - prefixState.invalidationTimer.StopLocked() + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() prefixState.validUntil = time.Time{} } else { var effectiveVl time.Duration @@ -1569,20 +1574,20 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } if effectiveVl != 0 { - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) prefixState.validUntil = now.Add(effectiveVl) } } // If DAD is not yet complete on the stable address, there is no need to do // work with temporary addresses. - if prefixState.stableAddr.ref.getKind() != permanent { + if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent { return } // Note, we do not need to update the entries in the temporary address map - // after updating the timers because the timers are held as pointers. + // after updating the jobs because the jobs are held as pointers. var regenForAddr tcpip.Address allAddressesRegenerated := true for tempAddr, tempAddrState := range prefixState.tempAddrs { @@ -1596,14 +1601,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation timer. + // reset the invalidation job. newValidLifetime := validUntil.Sub(now) if newValidLifetime <= 0 { ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) continue } - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.invalidationTimer.Reset(newValidLifetime) + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary // address is the lower of the preferred lifetime of the stable address or @@ -1616,17 +1621,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer preferred, deprecate it immediately. - // Otherwise, reset the deprecation timer. + // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { - ndp.deprecateSLAACAddress(tempAddrState.ref) + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) } else { - tempAddrState.ref.deprecated = false - tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + tempAddrState.addressEndpoint.SetDeprecated(false) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } - tempAddrState.regenTimer.StopLocked() + tempAddrState.regenJob.Cancel() if tempAddrState.regenerated { } else { allAddressesRegenerated = false @@ -1637,7 +1642,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // immediately after we finish iterating over the temporary addresses. regenForAddr = tempAddr } else { - tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) } } } @@ -1647,8 +1652,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // due to an update in preferred lifetime. // // If each temporay address has already been regenerated, no new temporary - // address will be generated. To ensure continuation of temporary SLAAC - // addresses, we manually try to regenerate an address here. + // address is generated. To ensure continuation of temporary SLAAC addresses, + // we manually try to regenerate an address here. if len(regenForAddr) != 0 || allAddressesRegenerated { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. @@ -1659,57 +1664,58 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } } -// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP -// dispatcher that ref has been deprecated. +// deprecateSLAACAddress marks the address as deprecated and notifies the NDP +// dispatcher that address has been deprecated. // -// deprecateSLAACAddress does nothing if ref is already deprecated. +// deprecateSLAACAddress does nothing if the address is already deprecated. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { - if ref.deprecated { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) { + if addressEndpoint.Deprecated() { return } - ref.deprecated = true - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) + addressEndpoint.SetDeprecated(true) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } // invalidateSLAACPrefix invalidates a SLAAC prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.stableAddr.ref; r != nil { + ndp.cleanupSLAACPrefixResources(prefix, state) + + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) + if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) } } - - ndp.cleanupSLAACPrefixResources(prefix, state) } // cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's // resources. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress { + if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.stableAddr.ref = nil + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil ndp.slaacPrefixes[prefix] = state return } @@ -1717,31 +1723,34 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. // // Panics if the SLAAC prefix is not known. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { // Invalidate all temporary addresses. for tempAddr, tempAddrState := range state.tempAddrs { ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) } - state.stableAddr.ref = nil - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() + if state.stableAddr.addressEndpoint != nil { + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil + } + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) } // invalidateTempSLAACAddr invalidates a temporary SLAAC address. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { // Since we are already invalidating the address, do not invalidate the // address when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) } ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) @@ -1750,10 +1759,10 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary // SLAAC address's resources from ndp. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } if !invalidateAddr { @@ -1775,37 +1784,31 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi } // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// timers and entry. +// jobs and entry. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationTimer.StopLocked() - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.regenTimer.StopLocked() + tempAddrState.addressEndpoint.DecRef() + tempAddrState.addressEndpoint = nil + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } -// cleanupState cleans up ndp's state. -// -// If hostOnly is true, then only host-specific state will be cleaned up. -// -// cleanupState MUST be called with hostOnly set to true when ndp's NIC is -// transitioning from a host to a router. This function will invalidate all -// discovered on-link prefixes, discovered routers, and auto-generated -// addresses. +// removeSLAACAddresses removes all SLAAC addresses. // -// If hostOnly is true, then the link-local auto-generated address will not be -// invalidated as routers are also expected to generate a link-local address. +// If keepLinkLocal is false, the SLAAC generated link-local address is removed. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalPrefixes := 0 + var linkLocalPrefixes int for prefix, state := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. - if hostOnly && prefix == linkLocalSubnet { + if keepLinkLocal && prefix == linkLocalSubnet { linkLocalPrefixes++ continue } @@ -1816,6 +1819,21 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) } +} + +// cleanupState cleans up ndp's state. +// +// If hostOnly is true, then only host-specific state is cleaned up. +// +// This function invalidates all discovered on-link prefixes, discovered +// routers, and auto-generated addresses. +// +// If hostOnly is true, then the link-local auto-generated address aren't +// invalidated as routers are also expected to generate a link-local address. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupState(hostOnly bool) { + ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1839,7 +1857,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // -// The NIC ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. @@ -1860,27 +1878,37 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { - ndp.nic.mu.Lock() + ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { + ndp.ep.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the NIC lock and stopped solicitations. - // Simply return here and do nothing further. - ndp.nic.mu.Unlock() + // goroutine already obtained the IPv6 endpoint lock and stopped + // solicitations. Simply return here and do nothing further. + ndp.ep.mu.Unlock() return } // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) - if ref == nil { - ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) + if addressEndpoint == nil { + // Incase this ends up creating a new temporary address, we need to hold + // onto the endpoint until a route is obtained. If we decrement the + // reference count before obtaing a route, the address's resources would + // be released and attempting to obtain a route after would fail. Once a + // route is obtainted, it is safe to decrement the reference count since + // obtaining a route increments the address's reference count. + addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) } - ndp.nic.mu.Unlock() + ndp.ep.mu.Unlock() - localAddr := ref.ep.ID().LocalAddress - r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + localAddr := addressEndpoint.AddressWithPrefix().Address + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) + addressEndpoint.DecRef() + if err != nil { + return + } defer r.Release() // Route should resolve immediately since @@ -1888,15 +1916,16 @@ func (ndp *ndpState) startSolicitingRouters() { // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. + // Since this method is required to be called when the IPv6 endpoint is + // not locked, the IPv6 endpoint could have been disabled or removed by + // another goroutine. if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { return } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1913,23 +1942,26 @@ func (ndp *ndpState) startSolicitingRouters() { } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) - pkt := header.ICMPv6(hdr.Prepend(payloadSize)) - pkt.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) rs.Options().Serialize(optsSerializer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, - NetworkHeaderParams{ + stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() - log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { @@ -1937,19 +1969,19 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.nic.mu.Lock() + ndp.ep.mu.Lock() if done || remaining == 0 { ndp.rtrSolicit.timer = nil ndp.rtrSolicit.done = nil } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but - // we still reached this point, then we know the NIC + // we still reached this point, then we know the IPv6 endpoint // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } - ndp.nic.mu.Unlock() + ndp.ep.mu.Unlock() }) } @@ -1957,7 +1989,7 @@ func (ndp *ndpState) startSolicitingRouters() { // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // -// The NIC ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { if ndp.rtrSolicit.timer == nil { // Nothing to do. @@ -1973,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() { // initializeTempAddrState initializes state related to temporary SLAAC // addresses. func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 64239ce9a..25464a03a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -17,7 +17,9 @@ package ipv6 import ( "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -30,12 +32,13 @@ import ( // setupStackAndEndpoint creates a stack with a single NIC with a link-local // address llladdr and an IPv6 endpoint to a remote with link-local address // rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { @@ -63,14 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } + t.Cleanup(ep.Close) return s, ep } +var _ NDPDispatcher = (*testNDPDispatcher)(nil) + +// testNDPDispatcher is an NDPDispatcher only allows default router discovery. +type testNDPDispatcher struct { + addr tcpip.Address +} + +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { + t.addr = addr + return true +} + +func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { + t.addr = addr +} + +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { +} + +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return false +} + +func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { +} + +func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { +} + +func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { +} + +func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { + var ndpDisp testNDPDispatcher + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ + NDPDisp: &ndpDisp, + })}, + }) + + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + + ipv6EP := ep.(*endpoint) + ipv6EP.mu.Lock() + ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.Unlock() + + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } + + ndpDisp.addr = "" + ndpEP := ep.(stack.NDPEndpoint) + ndpEP.InvalidateDefaultRouter(lladdr1) + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } +} + // TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -100,7 +183,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -136,9 +219,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -174,6 +257,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { } } +// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests +// that receiving a valid NDP NS message with the Source Link Layer Address +// option results in a new entry in the link address cache for the sender of +// the message. +func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + neighbors, err := s.Neighbors(nicID) + if err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + } + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + } + neighborByAddr[n.Addr] = n + } + + if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 { + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + if !ok { + t.Fatalf("expected a neighbor entry for %q", lladdr1) + } + if neigh.LinkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr) + } + if neigh.State != stack.Stale { + t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale) + } + } else { + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + + if ok { + t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + } + } + }) + } +} + func TestNeighorSolicitationResponse(t *testing.T) { const nicID = 1 nicAddr := lladdr0 @@ -183,6 +383,20 @@ func TestNeighorSolicitationResponse(t *testing.T) { remoteLinkAddr0 := linkAddr1 remoteLinkAddr1 := linkAddr2 + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + tests := []struct { name string nsOpts header.NDPOptionsSerializer @@ -341,86 +555,92 @@ func TestNeighorSolicitationResponse(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(1, 1280, nicLinkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: stackTyp.useNeighborCache, + }) + e := channel.New(1, 1280, nicLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } + // If we expected the NS to be invalid, we have nothing else to check. + return + } - // If we expected the NS to be invalid, we have nothing else to check. - return - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + p, got := e.Read() + if !got { + t.Fatal("expected an NDP NA response") + } - p, got := e.Read() - if !got { - t.Fatal("expected an NDP NA response") - } + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) } - - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) }) } } @@ -461,7 +681,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -497,9 +717,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -535,201 +755,385 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { } } -func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r - } - - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions buffer.View - if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) - } - - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } - ep.HandlePacket(r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: payload.ToVectorisedView(), - }) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) +// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests +// that receiving a valid NDP NA message with the Target Link Layer Address +// option does not result in a new entry in the neighbor cache for the target +// of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { + const nicID = 1 - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + tests := []struct { + name string + optsBuf []byte + isValid bool }{ { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + isValid: true, }, { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, }, { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, }, { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg + name: "Multiple", + optsBuf: []byte{ + 2, 1, 2, 3, 4, 5, 6, 7, + 2, 1, 2, 3, 4, 5, 6, 8, }, }, } - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code uint8 - valid bool + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + neighbors, err := s.Neighbors(nicID) + if err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + } + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + } + neighborByAddr[n.Addr] = n + } + + if neigh, ok := neighborByAddr[lladdr1]; ok { + t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + } + + if test.isValid { + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNDPValidation(t *testing.T) { + stacks := []struct { + name string + useNeighborCache bool }{ { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, + name: "linkAddrCache", + useNeighborCache: false, }, { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, + name: "neighborCache", + useNeighborCache: true, }, } - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + t.Helper() - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } + return s, ep, r + } - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View + if atomicFragment { + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr + nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + } - if t.Failed() { - t.FailNow() - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload) + len(extensions)), + NextHeader: nextHdr, + HopLimit: hopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } + ep.HandlePacket(r, pkt) + } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } + var sllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool + }{ + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + routerOnly: true, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + extraData: sllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } + + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" } - }) + + t.Run(name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() + + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + // RouterOnlyPacketsReceivedByHost count should initially be 0. + if got := routerOnly.Value(); got != 0 { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + + want = 0 + if test.valid && !isRouter && typ.routerOnly { + // RouterOnlyPacketsReceivedByHost count should have increased. + want = 1 + } + if got := routerOnly.Value(); got != want { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + } + + }) + } + }) + } } }) } + } // TestRouterAdvertValidation tests that when the NIC is configured to handle // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. func TestRouterAdvertValidation(t *testing.T) { + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + tests := []struct { name string src tcpip.Address hopLimit uint8 - code uint8 + code header.ICMPv6Code ndpPayload []byte expectedSuccess bool }{ @@ -846,61 +1250,67 @@ func TestRouterAdvertValidation(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: stackTyp.useNeighborCache, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + } + }) } }) } diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD new file mode 100644 index 000000000..c9e57dc0d --- /dev/null +++ b/pkg/tcpip/network/testutil/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + srcs = [ + "testutil.go", + ], + visibility = [ + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go new file mode 100644 index 000000000..7cc52985e --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil.go @@ -0,0 +1,144 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil defines types and functions used to test Network Layer +// functionality such as IP fragmentation. +package testutil + +import ( + "fmt" + "math/rand" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// MockLinkEndpoint is an endpoint used for testing, it stores packets written +// to it and can mock errors. +type MockLinkEndpoint struct { + // WrittenPackets is where packets written to the endpoint are stored. + WrittenPackets []*stack.PacketBuffer + + mtu uint32 + err *tcpip.Error + allowPackets int +} + +// NewMockLinkEndpoint creates a new MockLinkEndpoint. +// +// err is the error that will be returned once allowPackets packets are written +// to the endpoint. +func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { + return &MockLinkEndpoint{ + mtu: mtu, + err: err, + allowPackets: allowPackets, + } +} + +// MTU implements LinkEndpoint.MTU. +func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var n int + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + return n, err + } + n++ + } + + return n, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*MockLinkEndpoint) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*MockLinkEndpoint) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// MakeRandPkt generates a randomized packet. transportHeaderLength indicates +// how many random bytes will be copied in the Transport Header. +// extraHeaderReserveLength indicates how much extra space will be reserved for +// the other headers. The payload is made from Views of the sizes listed in +// viewSizes. +func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { + var views buffer.VectorisedView + + for _, s := range viewSizes { + newView := buffer.NewView(s) + if _, err := rand.Read(newView); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + views.AppendView(newView) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, + Data: views, + }) + pkt.NetworkProtocolNumber = proto + if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt +} diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index f6d592eb5..d87193650 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -400,7 +400,11 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { +// +// An optional testPort closure can be passed in which if provided will be used +// to test if the picked port can be used. The function should return true if +// the port is safe to use, false otherwise. +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -412,12 +416,23 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } + if testPort != nil && !testPort(port) { + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) + return 0, tcpip.ErrPortInUse + } return port, nil } // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil + if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { + return false, nil + } + if testPort != nil && !testPort(p) { + s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst) + return false, nil + } + return true, nil }) } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 58db5868c..4bc949fd8 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -332,7 +332,7 @@ func TestPortReservation(t *testing.T) { pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) if err != test.want { t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 0ab089208..51d428049 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -127,8 +127,8 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) @@ -182,7 +182,7 @@ func main() { if terr == tcpip.ErrConnectStarted { fmt.Println("Connect is pending...") <-notifyCh - terr = ep.GetSockOpt(tcpip.ErrorOption{}) + terr = ep.LastError() } wq.EventUnregister(&waitEntry) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 9e37cab18..8e0ee1cd7 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -112,8 +112,8 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) @@ -188,7 +188,7 @@ func main() { defer wq.EventUnregister(&waitEntry) for { - n, wq, err := ep.Accept() + n, wq, err := ep.Accept(nil) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index e65c731c2..2eaeab779 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -16,6 +16,18 @@ go_template_instance( ) go_template_instance( + name = "neighbor_entry_list", + out = "neighbor_entry_list.go", + package = "stack", + prefix = "neighborEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*neighborEntry", + "Linker": "*neighborEntry", + }, +) + +go_template_instance( name = "packet_buffer_list", out = "packet_buffer_list.go", package = "stack", @@ -27,20 +39,38 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ + "addressable_endpoint_state.go", "conntrack.go", - "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", - "ndp.go", + "neighbor_cache.go", + "neighbor_entry.go", + "neighbor_entry_list.go", + "neighborstate_string.go", "nic.go", + "nud.go", "packet_buffer.go", "packet_buffer_list.go", "rand.go", @@ -50,6 +80,7 @@ go_library( "stack_global_state.go", "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ @@ -74,7 +105,9 @@ go_test( name = "stack_x_test", size = "medium", srcs = [ + "addressable_endpoint_state_test.go", "ndp_test.go", + "nud_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", @@ -83,19 +116,22 @@ go_test( deps = [ ":stack", "//pkg/rand", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) @@ -105,7 +141,10 @@ go_test( srcs = [ "forwarder_test.go", "linkaddrcache_test.go", + "neighbor_cache_test.go", + "neighbor_entry_test.go", "nic_test.go", + "packet_buffer_test.go", ], library = ":stack", deps = [ @@ -113,6 +152,9 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go new file mode 100644 index 000000000..4d3acab96 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -0,0 +1,753 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) +var _ AddressableEndpoint = (*AddressableEndpointState)(nil) + +// AddressableEndpointState is an implementation of an AddressableEndpoint. +type AddressableEndpointState struct { + networkEndpoint NetworkEndpoint + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + endpoints map[tcpip.Address]*addressState + primary []*addressState + + // groups holds the mapping between group addresses and the number of times + // they have been joined. + groups map[tcpip.Address]uint32 + } +} + +// Init initializes the AddressableEndpointState with networkEndpoint. +// +// Must be called before calling any other function on m. +func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { + a.networkEndpoint = networkEndpoint + + a.mu.Lock() + defer a.mu.Unlock() + a.mu.endpoints = make(map[tcpip.Address]*addressState) + a.mu.groups = make(map[tcpip.Address]uint32) +} + +// ReadOnlyAddressableEndpointState provides read-only access to an +// AddressableEndpointState. +type ReadOnlyAddressableEndpointState struct { + inner *AddressableEndpointState +} + +// AddrOrMatching returns an endpoint for the passed address that is consisdered +// bound to the wrapped AddressableEndpointState. +// +// If addr is an exact match with an existing address, that address is returned. +// Otherwise, f is called with each address and the address that f returns true +// for is returned. +// +// Returns nil of no address matches. +func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + if ep, ok := m.inner.mu.endpoints[addr]; ok { + if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { + return ep + } + } + + for _, ep := range m.inner.mu.endpoints { + if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { + return ep + } + } + + return nil +} + +// Lookup returns the AddressEndpoint for the passed address. +// +// Returns nil if the passed address is not associated with the +// AddressableEndpointState. +func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + ep, ok := m.inner.mu.endpoints[addr] + if !ok { + return nil + } + return ep +} + +// ForEach calls f for each address pair. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + for _, ep := range m.inner.mu.endpoints { + if !f(ep) { + return + } + } +} + +// ForEachPrimaryEndpoint calls f for each primary address. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + for _, ep := range m.inner.mu.primary { + f(ep) + } +} + +// ReadOnly returns a readonly reference to a. +func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { + return ReadOnlyAddressableEndpointState{inner: a} +} + +func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.releaseAddressStateLocked(addrState) +} + +// releaseAddressState removes addrState from s's address state (primary and endpoints list). +// +// Preconditions: a.mu must be write locked. +func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) { + oldPrimary := a.mu.primary + for i, s := range a.mu.primary { + if s == addrState { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + oldPrimary[len(oldPrimary)-1] = nil + break + } + } + delete(a.mu.endpoints, addrState.addr.Address) +} + +// AddAndAcquirePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// AddAndAcquireTemporaryAddress adds a temporary address. +// +// Returns tcpip.ErrDuplicateAddress if the address exists. +// +// The temporary address's endpoint is acquired and returned. +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// addAndAcquireAddressLocked adds, acquires and returns a permanent or +// temporary address. +// +// If the addressable endpoint already has the address in a non-permanent state, +// and addAndAcquireAddressLocked is adding a permanent address, that address is +// promoted in place and its properties set to the properties provided. If the +// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// returned, regardless the kind of address that is being added. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { + // attemptAddToPrimary is false when the address is already in the primary + // address list. + attemptAddToPrimary := true + addrState, ok := a.mu.endpoints[addr.Address] + if ok { + if !permanent { + // We are adding a non-permanent address but the address exists. No need + // to go any further since we can only promote existing temporary/expired + // addresses to permanent. + return nil, tcpip.ErrDuplicateAddress + } + + addrState.mu.Lock() + if addrState.mu.kind.IsPermanent() { + addrState.mu.Unlock() + // We are adding a permanent address but a permanent address already + // exists. + return nil, tcpip.ErrDuplicateAddress + } + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr)) + } + + // We now promote the address. + for i, s := range a.mu.primary { + if s == addrState { + switch peb { + case CanBePrimaryEndpoint: + // The address is already in the primary address list. + attemptAddToPrimary = false + case FirstPrimaryEndpoint: + if i == 0 { + // The address is already first in the primary address list. + attemptAddToPrimary = false + } else { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + } + case NeverPrimaryEndpoint: + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + break + } + } + } + + if addrState == nil { + addrState = &addressState{ + addressableEndpointState: a, + addr: addr, + } + a.mu.endpoints[addr.Address] = addrState + addrState.mu.Lock() + // We never promote an address to temporary - it can only be added as such. + // If we are actaully adding a permanent address, it is promoted below. + addrState.mu.kind = Temporary + } + + // At this point we have an address we are either promoting from an expired or + // temporary address to permanent, promoting an expired address to temporary, + // or we are adding a new temporary or permanent address. + // + // The address MUST be write locked at this point. + defer addrState.mu.Unlock() + + if permanent { + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr)) + } + + // Primary addresses are biased by 1. + addrState.mu.refs++ + addrState.mu.kind = Permanent + } + // Acquire the address before returning it. + addrState.mu.refs++ + addrState.mu.deprecated = deprecated + addrState.mu.configType = configType + + if attemptAddToPrimary { + switch peb { + case NeverPrimaryEndpoint: + case CanBePrimaryEndpoint: + a.mu.primary = append(a.mu.primary, addrState) + case FirstPrimaryEndpoint: + if cap(a.mu.primary) == len(a.mu.primary) { + a.mu.primary = append([]*addressState{addrState}, a.mu.primary...) + } else { + // Shift all the endpoints by 1 to make room for the new address at the + // front. We could have just created a new slice but this saves + // allocations when the slice has capacity for the new address. + primaryCount := len(a.mu.primary) + a.mu.primary = append(a.mu.primary, nil) + if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount { + panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount)) + } + a.mu.primary[0] = addrState + } + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + } + + return addrState, nil +} + +// RemovePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.mu.groups[addr]; ok { + panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) + } + + return a.removePermanentAddressLocked(addr) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + addrState, ok := a.mu.endpoints[addr] + if !ok { + return tcpip.ErrBadLocalAddress + } + + return a.removePermanentEndpointLocked(addrState) +} + +// RemovePermanentEndpoint removes the passed endpoint if it is associated with +// a and permanent. +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { + addrState, ok := ep.(*addressState) + if !ok || addrState.addressableEndpointState != a { + return tcpip.ErrInvalidEndpointState + } + + return a.removePermanentEndpointLocked(addrState) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { + if !addrState.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + addrState.SetKind(PermanentExpired) + a.decAddressRefLocked(addrState) + return nil +} + +// decAddressRef decrements the address's reference count and releases it once +// the reference count hits 0. +func (a *AddressableEndpointState) decAddressRef(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.decAddressRefLocked(addrState) +} + +// decAddressRefLocked is like decAddressRef but with locking requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) { + addrState.mu.Lock() + defer addrState.mu.Unlock() + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr)) + } + + addrState.mu.refs-- + + if addrState.mu.refs != 0 { + return + } + + // A non-expired permanent address must not have its reference count dropped + // to 0. + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind)) + } + + a.releaseAddressStateLocked(addrState) +} + +// MainAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.GetKind() == Permanent + }) + if ep == nil { + return tcpip.AddressWithPrefix{} + } + + addr := ep.AddressWithPrefix() + a.decAddressRefLocked(ep) + return addr +} + +// acquirePrimaryAddressRLocked returns an acquired primary address that is +// valid according to isValid. +// +// Precondition: e.mu must be read locked +func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState { + var deprecatedEndpoint *addressState + for _, ep := range a.mu.primary { + if !isValid(ep) { + continue + } + + if !ep.Deprecated() { + if ep.IncRef() { + // ep is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + a.decAddressRefLocked(deprecatedEndpoint) + deprecatedEndpoint = nil + } + + return ep + } + } else if deprecatedEndpoint == nil && ep.IncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of + // ep in case a doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, ep's reference count + // will be decremented. + deprecatedEndpoint = ep + } + } + + return deprecatedEndpoint +} + +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + a.mu.Lock() + defer a.mu.Unlock() + + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } + + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } + + return addrState + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + if err != nil { + // addAndAcquireAddressLocked only returns an error if the address is + // already assigned but we just checked above if the address exists so we + // expect no error. + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + } + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + return ep +} + +// AcquireOutgoingPrimaryAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.IsAssigned(allowExpired) + }) + + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + + return ep +} + +// PrimaryAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.primary { + // Don't include tentative, expired or temporary endpoints + // to avoid confusion and prevent the caller from using + // those. + switch ep.GetKind() { + case PermanentTentative, PermanentExpired, Temporary: + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// PermanentAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.endpoints { + if !ep.GetKind().IsPermanent() { + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// JoinGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) + if err != nil { + return false, err + } + // We have no need for the address endpoint. + a.decAddressRefLocked(ep) + } + + a.mu.groups[group] = joins + 1 + return !ok, nil +} + +// LeaveGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + return false, tcpip.ErrBadLocalAddress + } + + if joins == 1 { + a.removeGroupAddressLocked(group) + delete(a.mu.groups, group) + return true, nil + } + + a.mu.groups[group] = joins - 1 + return false, nil +} + +// IsInGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { + a.mu.RLock() + defer a.mu.RUnlock() + _, ok := a.mu.groups[group] + return ok +} + +func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { + if err := a.removePermanentAddressLocked(group); err != nil { + // removePermanentEndpointLocked would only return an error if group is + // not bound to the addressable endpoint, but we know it MUST be assigned + // since we have group in our map of groups. + panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) + } +} + +// Cleanup forcefully leaves all groups and removes all permanent addresses. +func (a *AddressableEndpointState) Cleanup() { + a.mu.Lock() + defer a.mu.Unlock() + + for group := range a.mu.groups { + a.removeGroupAddressLocked(group) + } + a.mu.groups = make(map[tcpip.Address]uint32) + + for _, ep := range a.mu.endpoints { + // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // not a permanent address. + if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) + } + } +} + +var _ AddressEndpoint = (*addressState)(nil) + +// addressState holds state for an address. +type addressState struct { + addressableEndpointState *AddressableEndpointState + addr tcpip.AddressWithPrefix + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + refs uint32 + kind AddressKind + configType AddressConfigType + deprecated bool + } +} + +// AddressWithPrefix implements AddressEndpoint. +func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { + return a.addr +} + +// GetKind implements AddressEndpoint. +func (a *addressState) GetKind() AddressKind { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.kind +} + +// SetKind implements AddressEndpoint. +func (a *addressState) SetKind(kind AddressKind) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.kind = kind +} + +// IsAssigned implements AddressEndpoint. +func (a *addressState) IsAssigned(allowExpired bool) bool { + if !a.addressableEndpointState.networkEndpoint.Enabled() { + return false + } + + switch a.GetKind() { + case PermanentTentative: + return false + case PermanentExpired: + return allowExpired + default: + return true + } +} + +// IncRef implements AddressEndpoint. +func (a *addressState) IncRef() bool { + a.mu.Lock() + defer a.mu.Unlock() + if a.mu.refs == 0 { + return false + } + + a.mu.refs++ + return true +} + +// DecRef implements AddressEndpoint. +func (a *addressState) DecRef() { + a.addressableEndpointState.decAddressRef(a) +} + +// ConfigType implements AddressEndpoint. +func (a *addressState) ConfigType() AddressConfigType { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.configType +} + +// SetDeprecated implements AddressEndpoint. +func (a *addressState) SetDeprecated(d bool) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.deprecated = d +} + +// Deprecated implements AddressEndpoint. +func (a *addressState) Deprecated() bool { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.deprecated +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go new file mode 100644 index 000000000..26787d0a3 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// TestAddressableEndpointStateCleanup tests that cleaning up an addressable +// endpoint state removes permanent addresses and leaves groups. +func TestAddressableEndpointStateCleanup(t *testing.T) { + var ep fakeNetworkEndpoint + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + var s stack.AddressableEndpointState + s.Init(&ep) + + addr := tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + } + + { + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + } + // We don't need the address endpoint. + ep.DecRef() + } + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep == nil { + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) + } + ep.DecRef() + } + + group := tcpip.Address("\x02") + if added, err := s.JoinGroup(group); err != nil { + t.Fatalf("s.JoinGroup(%s): %s", group, err) + } else if !added { + t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) + } + if !s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) + } + + s.Cleanup() + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) + } + } + if s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + } +} diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index af9c325ca..0cd1da11f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,9 +15,12 @@ package stack import ( + "encoding/binary" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) @@ -30,6 +33,10 @@ import ( // // Currently, only TCP tracking is supported. +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 + // Direction of the tuple. type direction int @@ -42,13 +49,19 @@ const ( type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. +// +// +stateify savable type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry + tupleID // conn is the connection tracking entry this tuple belongs to. @@ -61,6 +74,8 @@ type tuple struct { // tupleID uniquely identifies a connection in one direction. It currently // contains enough information to distinguish between any TCP or UDP // connection, and will need to be extended to support other protocols. +// +// +stateify savable type tupleID struct { srcAddr tcpip.Address srcPort uint16 @@ -83,6 +98,8 @@ func (ti tupleID) reply() tupleID { } // conn is a tracked connection. +// +// +stateify savable type conn struct { // original is the tuple in original direction. It is immutable. original tuple @@ -97,36 +114,98 @@ type conn struct { // update the state of tcb. It is immutable. tcbHook Hook - // mu protects tcb. - mu sync.Mutex - + // mu protects all mutable state. + mu sync.Mutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} + +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout +} + +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } } // ConnTrack tracks all connections created for NAT rules. Most users are -// expected to only call handlePacket and createConnFor. +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable type ConnTrack struct { - // mu protects conns. - mu sync.RWMutex + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 - // conns maintains a map of tuples needed for connection tracking for - // iptables NAT rules. It is protected by mu. - conns map[tupleID]tuple + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket +} + +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList } // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. +// +// Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader) - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := pkt.Network() + if netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol } @@ -136,15 +215,16 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { dstAddr: netHeader.DestinationAddress(), dstPort: tcpHeader.DestinationPort(), transProto: netHeader.TransportProtocol(), - netProto: header.IPv4ProtocolNumber, + netProto: pkt.NetworkProtocolNumber, }, nil } // newConn creates new connection. func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { conn := conn{ - manip: manip, - tcbHook: hook, + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), } conn.original = tuple{conn: &conn, tupleID: orig} conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} @@ -161,19 +241,35 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { if err != nil { return nil, dirOriginal } + return ct.connForTID(tid) +} - ct.mu.Lock() - defer ct.mu.Unlock() - - tuple, ok := ct.conns[tid] - if !ok { - return nil, dirOriginal +func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } } - return tuple.conn, tuple.direction + + return nil, dirOriginal } -// createConnFor creates a new conn for pkt. -func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -186,8 +282,8 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.MinIP - replyTID.srcPort = rt.MinPort + replyTID.srcAddr = rt.Addr + replyTID.srcPort = rt.Port var manip manipType switch hook { case Prerouting: @@ -196,23 +292,61 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg manip = manipDstOutput } conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} - // Add the changed tuple to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked - // list. - ct.mu.Lock() - defer ct.mu.Unlock() - ct.conns[tid] = conn.original - ct.conns[replyTID] = conn.reply +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() + } - return conn + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } + + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. // TODO(gvisor.dev/issue/170): Change address for Prerouting hook. func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + + netHeader := pkt.Network() + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -228,14 +362,28 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader.SetSourceAddress(conn.original.dstAddr) } - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacketOutput manipulates ports for packets in Output hook. func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + + netHeader := pkt.Network() + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -253,8 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -263,25 +410,39 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + return false } conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. if conn == nil { - // Connection not found for the packet or the packet is invalid. - return + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { + return false } switch hook { @@ -297,35 +458,184 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() - var st tcpconntrack.Result - tcpHeader := header.TCP(pkt.TransportHeader) - if conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return + } + + // We only track TCP connections. + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + return + } + + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { + return + } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + ct.insertConn(conn) +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} + +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval + } + // We've hit the maximum interval. + return idx, maxInterval +} + +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: +// * ct.mu is locked for reading. +// * bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false + } + + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) } + ct.buckets[bucket].tuples.Remove(tuple) - // Delete conn if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConn(conn) + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() } + + return true } -// deleteConn deletes the connection. -func (ct *ConnTrack) deleteConn(conn *conn) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { + // Lookup the connection. The reply's original destination + // describes the original address. + tid := tupleID{ + srcAddr: epID.LocalAddress, + srcPort: epID.LocalPort, + dstAddr: epID.RemoteAddress, + dstPort: epID.RemotePort, + transProto: header.TCPProtocolNumber, + netProto: netProto, + } + conn, _ := ct.connForTID(tid) if conn == nil { - return + // Not a tracked connection. + return "", 0, tcpip.ErrNotConnected + } else if conn.manip == manipNone { + // Unmanipulated connection. + return "", 0, tcpip.ErrInvalidOptionValue } - ct.mu.Lock() - defer ct.mu.Unlock() - - delete(ct.conns, conn.original.tupleID) - delete(ct.conns, conn.reply.tupleID) + return conn.original.dstAddr, conn.original.dstPort, nil } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index a6546cef0..4e4b00a92 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -20,8 +20,10 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -44,37 +46,37 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fwdTestNetworkEndpoint struct { + AddressableEndpointState + nicID tcpip.NICID - id NetworkEndpointID - prefixLen int proto *fwdTestNetworkProtocol dispatcher TransportDispatcher ep LinkEndpoint } -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + +func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { + return nil } -func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (*fwdTestNetworkEndpoint) Enabled() bool { + return true } -func (f *fwdTestNetworkEndpoint) PrefixLen() int { - return f.prefixLen +func (*fwdTestNetworkEndpoint) Disable() {} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { - return &f.id -} - func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -85,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr return 0 } -func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -96,9 +94,9 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] - b[srcAddrOffset] = f.id.LocalAddress[0] + b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) @@ -113,17 +111,28 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu return tcpip.ErrNotSupported } -func (*fwdTestNetworkEndpoint) Close() {} +func (f *fwdTestNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { addrCache *linkAddrCache + neigh *neighborCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) + + mu struct { + sync.RWMutex + forwarding bool + } } +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -141,42 +150,40 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add } func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = netHeader - pkt.Data.TrimFront(fwdTestNetHeaderLen) - return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { - return &fwdTestNetworkEndpoint{ - nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { + e := &fwdTestNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, - }, nil + ep: nic.LinkEndpoint(), + } + e.AddressableEndpointState.Init(e) + return e } -func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func (f *fwdTestNetworkProtocol) Close() {} +func (*fwdTestNetworkProtocol) Close() {} -func (f *fwdTestNetworkProtocol) Wait() {} +func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { - if f.addrCache != nil && f.onLinkAddressResolved != nil { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr) + f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) }) } return nil @@ -189,10 +196,25 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip return "", false } -func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding + +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + // fwdTestPacketInfo holds all the information about an outbound packet. type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress @@ -287,7 +309,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: &PacketBuffer{Data: vv}, + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), } select { @@ -301,16 +323,29 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocol{proto}, + NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + UseNeighborCache: useNeighborCache, }) - proto.addrCache = s.linkAddrCache + if !useNeighborCache { + proto.addrCache = s.linkAddrCache + } // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ @@ -338,6 +373,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f t.Fatal("AddAddress #2 failed:", err) } + if useNeighborCache { + // Control the neighbor cache for NIC 2. + nic, ok := s.nics[2] + if !ok { + t.Fatal("failed to get the neighbor cache for NIC 2") + } + proto.neigh = nic.neigh + } + // Route all packets to NIC 2. { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -351,79 +395,129 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) - var p fwdTestPacketInfo + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } + var p fwdTestPacketInfo - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } func TestForwardingWithFakeResolver(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any address will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - var p fwdTestPacketInfo + var p fwdTestPacketInfo - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } @@ -431,15 +525,17 @@ func TestForwardingWithNoResolver(t *testing.T) { // Create a network protocol without a resolver. proto := &fwdTestNetworkProtocol{} - ep1, ep2 := fwdTestNetFactory(t, proto) + // Whether or not we use the neighbor cache here does not matter since + // neither linkAddrCache nor neighborCache will be used. + ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) select { case <-ep2.C: @@ -449,202 +545,334 @@ func TestForwardingWithNoResolver(t *testing.T) { } func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - } + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) - } + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) - } + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). - if !ok { - t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) - } + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go new file mode 100644 index 000000000..5efddfaaf --- /dev/null +++ b/pkg/tcpip/stack/headertype_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type headerType ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] +} + +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" + +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} + +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 974d77c36..8d6d9a7f1 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,104 +16,186 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) -// Table names. +// tableID is an index into IPTables.tables. +type tableID int + const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" + natID tableID = iota + mangleID + filterID + numTables ) -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +// Table names. const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" + NATTable = "nat" + MangleTable = "mangle" + FilterTable = "filter" ) +// nameToID is immutable. +var nameToID = map[string]tableID{ + NATTable: natID, + MangleTable: mangleID, + FilterTable: filterID, +} + // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. return &IPTables{ - tables: map[string]Table{ - TablenameNat: Table{ + v4Tables: [numTables]Table{ + natID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - Underflows: map[Hook]int{ + Underflows: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - UserChains: map[string]int{}, }, - TablenameMangle: Table{ + mangleID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Output: 1, }, - Underflows: map[Hook]int{ + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + v6Tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ Prerouting: 0, Output: 1, }, - UserChains: map[string]int{}, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, }, - TablenameFilter: Table{ + filterID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, }, - priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + priorities: [NumHooks][]tableID{ + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, connections: ConnTrack{ - conns: make(map[tupleID]tuple), + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -122,62 +204,66 @@ func DefaultTables() *IPTables { func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, } } -// EmptyNatTable returns a Table with no rules and the filter table chains +// EmptyNATTable returns a Table with no rules and the filter table chains // mapped to HookUnset. -func EmptyNatTable() Table { +func EmptyNATTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + BuiltinChains: [NumHooks]int{ + Forward: HookUnset, }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + Underflows: [NumHooks]int{ + Forward: HookUnset, }, - UserChains: map[string]int{}, } } -// GetTable returns table by name. -func (it *IPTables) GetTable(name string) (Table, bool) { +// GetTable returns a table by name. +func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { + id, ok := nameToID[name] + if !ok { + return Table{}, false + } it.mu.RLock() defer it.mu.RUnlock() - t, ok := it.tables[name] - return t, ok + if ipv6 { + return it.v6Tables[id], true + } + return it.v4Tables[id], true } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) { +func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { + id, ok := nameToID[name] + if !ok { + return tcpip.ErrInvalidOptionValue + } it.mu.Lock() defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } it.modified = true - it.tables[name] = table -} - -// GetPriorities returns slice of priorities associated with hook. -func (it *IPTables) GetPriorities(hook Hook) []string { - it.mu.RLock() - defer it.mu.RUnlock() - return it.priorities[hook] + if ipv6 { + it.v6Tables[id] = table + } else { + it.v4Tables[id] = table + } + return nil } // A chainVerdict is what a table decides should be done with a packet. @@ -199,26 +285,43 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// which address and nicName can be gathered. Currently, address is only +// needed for prerouting and nicName is only needed for output. +// // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { + if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { + return true + } // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() + defer it.mu.RUnlock() if !it.modified { - it.mu.RUnlock() return true } - it.mu.RUnlock() // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.GetPriorities(hook) { - table, _ := it.GetTable(tablename) + priorities := it.priorities[hook] + for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } + var table Table + if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber { + table = it.v6Tables[tableID] + } else { + table = it.v4Tables[tableID] + } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -229,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -245,17 +348,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -279,14 +424,14 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { +// Preconditions: +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { case RuleAccept: return chainAccept @@ -303,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -326,25 +471,14 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { +// Preconditions: +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(pkt, hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -363,5 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) +} + +// OriginalDst returns the original destination of redirected connections. It +// returns an error if the connection doesn't exist or isn't redirected. +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { + it.mu.RLock() + defer it.mu.RUnlock() + if !it.modified { + return "", 0, tcpip.ErrNotConnected + } + return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d43f60c67..538c4625d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -21,117 +21,178 @@ import ( ) // AcceptTarget accepts packets. -type AcceptTarget struct{} +type AcceptTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (at *AcceptTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: at.NetworkProtocol, + } +} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } // DropTarget drops packets. -type DropTarget struct{} +type DropTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (dt *DropTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: dt.NetworkProtocol, + } +} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. -type ErrorTarget struct{} +type ErrorTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (et *ErrorTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: et.NetworkProtocol, + } +} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } // UserChainTarget marks a rule as the beginning of a user chain. type UserChainTarget struct { + // Name is the chain name. Name string + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (uc *UserChainTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: uc.NetworkProtocol, + } } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } // ReturnTarget returns from the current chain. If the chain is a built-in, the // hook's underflow should be called. -type ReturnTarget struct{} +type ReturnTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (rt *ReturnTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: rt.NetworkProtocol, + } +} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const RedirectTargetName = "REDIRECT" + // RedirectTarget redirects the packet by modifying the destination port/IP. -// Min and Max values for IP and Ports in the struct indicate the range of -// values which can be used to redirect. +// TODO(gvisor.dev/issue/170): Other flags need to be added after we support +// them. type RedirectTarget struct { - // TODO(gvisor.dev/issue/170): Other flags need to be added after - // we support them. - // RangeProtoSpecified flag indicates single port is specified to - // redirect. - RangeProtoSpecified bool + // Addr indicates address used to redirect. + Addr tcpip.Address - // MinIP indicates address used to redirect. - MinIP tcpip.Address + // Port indicates port used to redirect. + Port uint16 - // MaxIP indicates address used to redirect. - MaxIP tcpip.Address - - // MinPort indicates port used to redirect. - MinPort uint16 + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} - // MaxPort indicates port used to redirect. - MaxPort uint16 +// ID implements Target.ID. +func (rt *RedirectTarget) ID() TargetID { + return TargetID{ + Name: RedirectTargetName, + NetworkProtocol: rt.NetworkProtocol, + } } // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 } // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1) in Output and - // to primary address of the incoming interface in Prerouting. + // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // the primary address of the incoming interface in Prerouting. switch hook { case Output: - rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) - rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + } else { + rt.Addr = header.IPv6Loopback + } case Prerouting: - rt.MinIP = address - rt.MaxIP = address + rt.Addr = address default: panic("redirect target is supported only on output and prerouting hooks") } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader) - switch protocol := netHeader.TransportProtocol(); protocol { + switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader) - udpHeader.SetDestinationPort(rt.MinPort) + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetDestinationPort(rt.Port) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(protocol, length) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) @@ -140,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - // Change destination address. - netHeader.SetDestinationAddress(rt.MinIP) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + + pkt.Network().SetDestinationAddress(rt.Addr) + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -153,7 +219,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index c528ec381..7b3f3e88b 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "strings" "sync" @@ -78,50 +79,66 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { - // mu protects tables, priorities, and modified. + // mu protects v4Tables, v6Tables, and modified. mu sync.RWMutex - - // tables maps table names to tables. User tables have arbitrary names. - // mu needs to be locked for accessing. - tables map[string]Table - - // priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. mu needs to be locked for accessing. - priorities map[Hook][]string - + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. mu must be locked for accessing. + v4Tables [numTables]Table + v6Tables [numTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. It is immutable. + priorities [NumHooks][]tableID + connections ConnTrack + + // reaperDone can be signaled to stop the reaper goroutine. + reaperDone chan struct{} } -// A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules. +// A Table defines a set of chains and hooks into the network stack. +// +// It is a list of Rules, entry points (BuiltinChains), and error handlers +// (Underflows). As packets traverse netstack, they hit hooks. When a packet +// hits a hook, iptables compares it to Rules starting from that hook's entry +// point. So if a packet hits the Input hook, we look up the corresponding +// entry point in BuiltinChains and jump to that point. +// +// If the Rule doesn't match the packet, iptables continues to the next Rule. +// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept +// or RuleDrop) that causes the packet to stop traversing iptables. It can also +// jump to other rules or perform custom actions based on Rule.Target. +// +// Underflow Rules are invoked when a chain returns without reaching a verdict. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule // BuiltinChains maps builtin chains to their entrypoint rule in Rules. - BuiltinChains map[Hook]int + BuiltinChains [NumHooks]int // Underflows maps builtin chains to their underflow rule in Rules // (i.e. the rule to execute if the chain returns without a verdict). - Underflows map[Hook]int - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]int + Underflows [NumHooks]int } // ValidHooks returns a bitmap of the builtin hooks for the given table. func (table *Table) ValidHooks() uint32 { hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook + for hook, ruleIdx := range table.BuiltinChains { + if ruleIdx != HookUnset { + hooks |= 1 << hook + } } return hooks } @@ -130,6 +147,8 @@ func (table *Table) ValidHooks() uint32 { // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -141,11 +160,18 @@ type Rule struct { Target Target } -// IPHeaderFilter holds basic IP filtering data common to every rule. +// IPHeaderFilter performs basic IP header matching common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber + // CheckProtocol determines whether the Protocol field should be + // checked during matching. + // TODO(gvisor.dev/issue/3549): Check this field during matching. + CheckProtocol bool + // Dst matches the destination IP address. Dst tcpip.Address @@ -182,16 +208,43 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } -// match returns whether hdr matches the filter. -func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. +// match returns whether pkt matches the filter. +// +// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 +// or IPv6 header length. +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { + // Extract header fields. + var ( + // TODO(gvisor.dev/issue/170): Support other filter fields. + transProto tcpip.TransportProtocolNumber + dstAddr tcpip.Address + srcAddr tcpip.Address + ) + switch proto := pkt.NetworkProtocolNumber; proto { + case header.IPv4ProtocolNumber: + hdr := header.IPv4(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + case header.IPv6ProtocolNumber: + hdr := header.IPv6(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + default: + panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto)) + } + // Check the transport protocol. - if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + if fl.CheckProtocol && fl.Protocol != transProto { return false } - // Check the source and destination IPs. - if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + // Check the addresses. + if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) || + !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) { return false } @@ -219,6 +272,18 @@ func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool return true } +// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header +// applies. +func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber { + switch len(fl.Src) { + case header.IPv4AddressSize: + return header.IPv4ProtocolNumber + case header.IPv6AddressSize: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src)) +} + // filterAddress returns whether addr matches the filter. func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { matches := true @@ -244,8 +309,23 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } +// A TargetID uniquely identifies a target. +type TargetID struct { + // Name is the target name as stored in the xt_entry_target struct. + Name string + + // NetworkProtocol is the protocol to which the target applies. + NetworkProtocol tcpip.NetworkProtocolNumber + + // Revision is the version of the target. + Revision uint8 +} + // A Target is the interface for taking an action for a packet. type Target interface { + // ID uniquely identifies the Target. + ID() TargetID + // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 403557fd7..6f73a0ce4 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 1baa498d0..33806340e 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math" "sync/atomic" "testing" "time" @@ -48,7 +49,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) if f := r.onLinkAddressRequest; f != nil { f() @@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) { } func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) + // There is a race condition causing this test to fail when the executor + // takes longer than the resolution timeout to call linkAddrCache.get. This + // is especially common when this test is run with gotsan. + // + // Using a large resolution timeout decreases the probability of experiencing + // this race condition and does not affect how long this test takes to run. + c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) @@ -275,3 +282,71 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } + +// TestCacheWaker verifies that RemoveWaker removes a waker previously added +// through get(). +func TestCacheWaker(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + + // First, sanity check that wakers are working. + { + linkRes := &testLinkAddressResolver{cache: c} + s := sleep.Sleeper{} + defer s.Done() + + const wakerID = 1 + w := sleep.Waker{} + s.AddWaker(&w, wakerID) + + e := testAddrs[0] + + if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) + } + id, ok := s.Fetch(true /* block */) + if !ok { + t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") + } + if id != wakerID { + t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) + } + + if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { + t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) + } else if got != e.linkAddr { + t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) + } + } + + // Check that RemoveWaker works. + { + linkRes := &testLinkAddressResolver{cache: c} + s := sleep.Sleeper{} + defer s.Done() + + const wakerID = 2 // different than the ID used in the sanity check + w := sleep.Waker{} + s.AddWaker(&w, wakerID) + + e := testAddrs[1] + linkRes.onLinkAddressRequest = func() { + // Remove the waker before the linkAddrCache has the opportunity to send + // a notification. + c.removeWaker(e.addr, &w) + } + + if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) + } + + if got, err := getBlocking(c, e.addr, linkRes); err != nil { + t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) + } else if got != e.linkAddr { + t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Fatalf("unexpected notification from waker with id %d", id) + } + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6f86abc98..73a01c2dd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -150,10 +150,10 @@ type ndpDNSSLEvent struct { type ndpDHCPv6Event struct { nicID tcpip.NICID - configuration stack.DHCPv6ConfigurationFromNDPRA + configuration ipv6.DHCPv6ConfigurationFromNDPRA } -var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. @@ -170,7 +170,7 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ @@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } } -// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip. return n.rememberRouter } -// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip } } -// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip return n.rememberPrefix } -// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi } } -// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { if c := n.rdnssC; c != nil { c <- ndpRDNSSEvent{ @@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } -// Implements stack.NDPDispatcher.OnDNSSearchListOption. +// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { if n.dnsslC != nil { n.dnsslC <- ndpDNSSLEvent{ @@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s } } -// Implements stack.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { +// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { c <- ndpDHCPv6Event{ nicID, @@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = test.retransTimer - opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits e := channelLinkWithHeaderLength{ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ipv6.NDPConfigurations{ + RetransmitTimer: test.retransTimer, + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + }, + })}, + }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -541,7 +542,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -550,14 +551,34 @@ func TestDADResolve(t *testing.T) { checker.NDPNSOptions(nil), )) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } }) } } +func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(tgt) + snmc := header.SolicitedNodeAddr(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) +} + // TestDADFail tests to make sure that the DAD process fails if another node is // detected to be performing DAD on the same address (receive an NS message from // a node doing DAD for the same address), or if another node is detected to own @@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) { tests := []struct { name string - makeBuf func(tgt tcpip.Address) buffer.Prependable + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "RxSolicit", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RxSolicit", + rxPkt: rxNDPSolicit, + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, }, { - "RxAdvert", - func(tgt tcpip.Address) buffer.Prependable { + name: "RxAdvert", + rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) @@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) { SrcAddr: tgt, DstAddr: header.IPv6AllNodesMulticastAddress, }) - - return hdr - + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, }, @@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = time.Second * 2 + ndpConfigs := ipv6.DefaultNDPConfigurations() + ndpConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -664,12 +663,8 @@ func TestDADFail(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Receive a packet to simulate multiple nodes owning or - // attempting to own the same address. - hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + // Receive a packet to simulate an address conflict. + test.rxPkt(e, addr1) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -753,18 +748,19 @@ func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.NDPConfigurations{ + + ndpConfigs := ipv6.NDPConfigurations{ RetransmitTimer: time.Second, DupAddrDetectTransmits: 2, } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -814,19 +810,6 @@ func TestDADStop(t *testing.T) { } } -// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NDP configurations using an invalid NICID. -func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - - // No NIC with ID 1 yet. - if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) - } -} - // TestSetNDPConfigurations tests that we can update and use per-interface NDP // configurations without affecting the default NDP configurations or other // interfaces' configurations. @@ -862,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -891,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) { } // Update the NDP configurations on NIC(1) to use DAD. - configs := stack.NDPConfigurations{ + configs := ipv6.NDPConfigurations{ DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(nicID1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(configs) } // Created after updating NIC(1)'s NDP configurations @@ -1024,7 +1011,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) } // raBufWithOpts returns a valid NDP Router Advertisement with options. @@ -1110,14 +1099,15 @@ func TestNoRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1148,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1189,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func(addr tcpip.Address, discovered bool) { @@ -1254,7 +1246,7 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Wait for lladdr2's router invalidation timer to fire. The lifetime + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1271,7 +1263,7 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Wait for lladdr3's router invalidation timer to fire. The lifetime + // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1282,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) { } // TestRouterDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), @@ -1290,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1303,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - if i <= stack.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredDefaultRouters { select { case e := <-ndpDisp.routerC: if diff := checkRouterEvent(e, llAddr, true); diff != "" { @@ -1355,14 +1348,15 @@ func TestNoPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1396,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1442,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1502,7 +1498,7 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Wait for prefix2's most recent invalidation timer plus some buffer to + // Wait for prefix2's most recent invalidation job plus some buffer to // expire. select { case e := <-ndpDisp.prefixC: @@ -1542,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1618,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // TestPrefixDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) } - optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link // prefixes. - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} prefixAddr[7] = byte(i) prefix := tcpip.AddressWithPrefix{ @@ -1662,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < stack.MaxDiscoveredOnLinkPrefixes { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < ipv6.MaxDiscoveredOnLinkPrefixes { select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { @@ -1689,13 +1687,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) AddressWithPrefix: item, } - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false + return containsAddr(list, protocolAddress) } // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. @@ -1719,14 +1711,15 @@ func TestNoAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1752,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr(t *testing.T) { +func TestAutoGenAddr2(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1769,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1879,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := stack.MaxDesyncFactor + savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MaxDesyncFactor = time.Nanosecond prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1934,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2122,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond tests := []struct { name string @@ -2163,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2214,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = 0 + ipv6.MaxDesyncFactor = 0 prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2231,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2297,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2320,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2385,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Wait for all the temporary addresses to get invalidated. @@ -2395,7 +2396,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { for _, addr := range tempAddrs { // Wait for a deprecation then invalidation event, or just an invalidation // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation timers could fire in any + // cases because the deprecation and invalidation jobs could execute in any // order. select { case e := <-ndpDisp.autoGenAddrC: @@ -2432,9 +2433,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } } -// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's -// regeneration timer gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( nicID = 1 regenAfter = 2 * time.Second @@ -2442,17 +2443,17 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2465,16 +2466,17 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2533,7 +2535,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { // // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a timer. + // - this regeneration does not depend on a job. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(tempAddr2, newAddr) @@ -2548,9 +2550,12 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { // as paased. ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) select { case e := <-ndpDisp.autoGenAddrC: @@ -2559,18 +2564,16 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { } // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration timer gets reset. + // RA, the regeneration job gets scheduled again. // // The maximum lifetime is the sum of the minimum lifetimes for temporary // addresses + the time that has already passed since the last address was - // generated so that the regeneration timer is needed to generate the next + // generated so that the regeneration job is needed to generate the next // address. newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) - } + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } @@ -2658,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) s.SetRouteTable([]tcpip.Route{{ @@ -2742,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { ndpDisp.dadC = make(chan ndpDADEvent, 2) ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits ndpConfigs.RetransmitTimer = retransmitTimer - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Do SLAAC for prefix. @@ -2757,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // DAD failure to restart the local generation process. addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] expectAutoGenAddrAsyncEvent(addr, newAddr) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { @@ -2790,20 +2795,22 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2813,7 +2820,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd Gateway: llAddr3, NIC: nicID, }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if useNeighborCache { + s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + } else { + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + } return ndpDisp, e, s } @@ -2887,329 +2898,366 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA // TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when // receiving a PI with 0 preferred lifetime. func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + const nicID = 1 - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - default: - t.Fatal("expected addr auto gen event") - } - } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + }) } - expectPrimaryAddr(addr2) } -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated // when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { +func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate + defer func() { + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved + }() + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - default: - t.Fatal("expected addr auto gen event") - } - } - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). select { case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") } - case <-time.After(defaultAsyncPositiveEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } - } else { - t.Fatalf("got unexpected auto-generated event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } + }) } } @@ -3219,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 const minVLSeconds = 1 savedIL := header.NDPInfiniteLifetime - savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL header.NDPInfiniteLifetime = savedIL }() - stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3268,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3318,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 const newMinVL = 4 - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3410,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { } e := channel.New(10, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3476,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3513,8 +3564,8 @@ func TestAutoGenAddrRemoval(t *testing.T) { } expectAutoGenAddrEvent(addr, invalidatedAddr) - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") @@ -3527,110 +3578,128 @@ func TestAutoGenAddrRemoval(t *testing.T) { func TestAutoGenAddrAfterRemoval(t *testing.T) { const nicID = 1 - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) + + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) + }) } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) } // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that @@ -3643,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3724,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3799,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const lifetimeSeconds = 10 // Needed for the temporary address sub test. - savedMaxDesync := stack.MaxDesyncFactor + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] @@ -3881,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -3906,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, }, @@ -3920,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "Temporary address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -3972,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, }, - SecretKey: secretKey, - }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4002,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) expectDADEvent(t, &ndpDisp, addr.Address, false) @@ -4062,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool subnet tcpip.Subnet triggerSLAACFn func(e *channel.Endpoint) }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -4085,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, AutoGenAddressConflictRetries: maxRetries, @@ -4108,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4141,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4193,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4239,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Simulate a DAD conflict after some time has passed. time.Sleep(failureTimer) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4402,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4452,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4583,7 +4653,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -4637,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func() (bool, ndpRouterEvent) { @@ -4910,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { t.Helper() select { case e := <-ndpDisp.dhcpv6ConfigurationC: @@ -4945,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Even if the first RA reports no DHCPv6 configurations are available, the // dispatcher should get an event. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) // Receiving the same update again should not result in an event to the // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) @@ -4954,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to Managed Address. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to none. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -4974,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // // Note, when the M flag is set, the O flag is redundant. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) expectNoDHCPv6Event() // Even though the DHCPv6 flags are different, the effective configuration is @@ -4987,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -5002,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() } @@ -5140,16 +5212,15 @@ func TestRouterSolicitation(t *testing.T) { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } - checker.IPv6(t, - p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } waitForNothing := func(timeout time.Duration) { @@ -5161,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) { } } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5230,11 +5302,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, @@ -5294,19 +5366,20 @@ func TestStopStartSolicitingRouters(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go new file mode 100644 index 000000000..27e1feec0 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -0,0 +1,333 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const neighborCacheSize = 512 // max entries per interface + +// neighborCache maps IP addresses to link addresses. It uses the Least +// Recently Used (LRU) eviction strategy to implement a bounded cache for +// dynmically acquired entries. It contains the state machine and configuration +// for running Neighbor Unreachability Detection (NUD). +// +// There are two types of entries in the neighbor cache: +// 1. Dynamic entries are discovered automatically by neighbor discovery +// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm +// reachability with the device once the entry's state becomes Stale. +// 2. Static entries are explicitly added by a user and have no expiration. +// Their state is always Static. The amount of static entries stored in the +// cache is unbounded. +// +// neighborCache implements NUDHandler. +type neighborCache struct { + nic *NIC + state *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList + + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } +} + +var _ NUDHandler = (*neighborCache)(nil) + +// getOrCreateEntry retrieves a cache entry associated with addr. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[remoteAddr]; ok { + entry.mu.RLock() + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.lru.PushFront(entry) + } + entry.mu.RUnlock() + return entry + } + + // The entry that needs to be created must be dynamic since all static + // entries are directly added to the cache via addStaticEntry. + entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + if n.dynamic.count == neighborCacheSize { + e := n.dynamic.lru.Back() + e.mu.Lock() + + delete(n.cache, e.neigh.Addr) + n.dynamic.lru.Remove(e) + n.dynamic.count-- + + e.dispatchRemoveEventLocked() + e.setStateLocked(Unknown) + e.notifyWakersLocked() + e.mu.Unlock() + } + n.cache[remoteAddr] = entry + n.dynamic.lru.PushFront(entry) + n.dynamic.count++ + return entry +} + +// entry looks up the neighbor cache for translating address to link address +// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there +// is a LinkAddressResolver registered with the network protocol, the cache +// attempts to resolve the address and returns ErrWouldBlock. If a Waker is +// provided, it will be notified when address resolution is complete (success +// or not). +// +// If address resolution is required, ErrNoLinkAddress and a notification +// channel is returned for the top level caller to block. Channel is closed +// once address resolution is complete (success or not). +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } + return e, nil, nil + } + + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + defer entry.mu.Unlock() + + switch s := entry.neigh.State; s { + case Reachable, Static: + return entry.neigh, nil, nil + + case Unknown, Incomplete, Stale, Delay, Probe: + entry.addWakerLocked(w) + + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.neigh, nil, tcpip.ErrNoLinkAddress + } + entry.done = make(chan struct{}) + } + + entry.handlePacketQueuedLocked() + return entry.neigh, entry.done, tcpip.ErrWouldBlock + + case Failed: + return entry.neigh, nil, tcpip.ErrNoLinkAddress + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", s)) + } +} + +// removeWaker removes a waker that has been added when link resolution for +// addr was requested. +func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { + n.mu.Lock() + if entry, ok := n.cache[addr]; ok { + delete(entry.wakers, waker) + } + n.mu.Unlock() +} + +// entries returns all entries in the neighbor cache. +func (n *neighborCache) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(n.cache)) + n.mu.RLock() + for _, entry := range n.cache { + entry.mu.RLock() + entries = append(entries, entry.neigh) + entry.mu.RUnlock() + } + n.mu.RUnlock() + return entries +} + +// addStaticEntry adds a static entry to the neighbor cache, mapping an IP +// address to a link address. If a dynamic entry exists in the neighbor cache +// with the same address, it will be replaced with this static entry. If a +// static entry exists with the same address but different link address, it +// will be updated with the new link address. If a static entry exists with the +// same address and link address, nothing will happen. +func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[addr]; ok { + entry.mu.Lock() + if entry.neigh.State != Static { + // Dynamic entry found with the same address. + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } else if entry.neigh.LinkAddr == linkAddr { + // Static entry found with the same address and link address. + entry.mu.Unlock() + return + } else { + // Static entry found with the same address but different link address. + entry.neigh.LinkAddr = linkAddr + entry.dispatchChangeEventLocked(entry.neigh.State) + entry.mu.Unlock() + return + } + + // Notify that resolution has been interrupted, just in case the entry was + // in the Incomplete or Probe state. + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = entry +} + +// removeEntryLocked removes the specified entry from the neighbor cache. +func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + if entry.neigh.State != Failed { + entry.dispatchRemoveEventLocked() + } + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + + delete(n.cache, entry.neigh.Addr) +} + +// removeEntry removes a dynamic or static entry by address from the neighbor +// cache. Returns true if the entry was found and deleted. +func (n *neighborCache) removeEntry(addr tcpip.Address) bool { + n.mu.Lock() + defer n.mu.Unlock() + + entry, ok := n.cache[addr] + if !ok { + return false + } + + entry.mu.Lock() + defer entry.mu.Unlock() + + n.removeEntryLocked(entry) + return true +} + +// clear removes all dynamic and static entries from the neighbor cache. +func (n *neighborCache) clear() { + n.mu.Lock() + defer n.mu.Unlock() + + for _, entry := range n.cache { + entry.mu.Lock() + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + n.dynamic.lru = neighborEntryList{} + n.cache = make(map[tcpip.Address]*neighborEntry) + n.dynamic.count = 0 +} + +// config returns the NUD configuration. +func (n *neighborCache) config() NUDConfigurations { + return n.state.Config() +} + +// setConfig changes the NUD configuration. +// +// If config contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *neighborCache) setConfig(config NUDConfigurations) { + config.resetInvalidFields() + n.state.SetConfig(config) +} + +// HandleProbe implements NUDHandler.HandleProbe by following the logic defined +// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled +// by the caller. +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + entry.handleProbeLocked(remoteLinkAddr) + entry.mu.Unlock() +} + +// HandleConfirmation implements NUDHandler.HandleConfirmation by following the +// logic defined in RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce cryptographically +// generated addresses, as defined in RFC 3972, Cryptographically Generated +// Addresses (CGA). This ensures that the claimed source of an NDP message is +// the owner of the claimed address. +func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleConfirmationLocked(linkAddr, flags) + entry.mu.Unlock() + } + // The confirmation SHOULD be silently discarded if the recipient did not + // initiate any communication with the target. This is indicated if there is + // no matching entry for the remote address. +} + +// HandleUpperLevelConfirmation implements +// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in +// RFC 4861 section 7.3.1. +func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleUpperLevelConfirmationLocked() + entry.mu.Unlock() + } +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go new file mode 100644 index 000000000..a0b7da5cd --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -0,0 +1,1727 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/rand" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" +) + +const ( + // entryStoreSize is the default number of entries that will be generated and + // added to the entry store. This number needs to be larger than the size of + // the neighbor cache to give ample opportunity for verifying behavior during + // cache overflows. Four times the size of the neighbor cache allows for + // three complete cache overflows. + entryStoreSize = 4 * neighborCacheSize + + // typicalLatency is the typical latency for an ARP or NDP packet to travel + // to a router and back. + typicalLatency = time.Millisecond + + // testEntryBroadcastAddr is a special address that indicates a packet should + // be sent to all nodes. + testEntryBroadcastAddr = tcpip.Address("broadcast") + + // testEntryLocalAddr is the source address of neighbor probes. + testEntryLocalAddr = tcpip.Address("local_addr") + + // testEntryBroadcastLinkAddr is a special link address sent back to + // multicast neighbor probes. + testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") + + // infiniteDuration indicates that a task will not occur in our lifetime. + infiniteDuration = time.Duration(math.MaxInt64) +) + +// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor +// entries. The UpdatedAt field is ignored due to a lack of a deterministic +// method to predict the time that an event will be dispatched. +func entryDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + } +} + +// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to +// sort slices of entries for cases where ordering must be ignored. +func entryDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { + config.resetInvalidFields() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return &neighborCache{ + nic: &NIC{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + }, + id: 1, + }, + state: NewNUDState(config, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } +} + +// testEntryStore contains a set of IP to NeighborEntry mappings. +type testEntryStore struct { + mu sync.RWMutex + entriesMap map[tcpip.Address]NeighborEntry +} + +func toAddress(i int) tcpip.Address { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint16(i)) + return tcpip.Address(buf.String()) +} + +func toLinkAddress(i int) tcpip.LinkAddress { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint32(i)) + return tcpip.LinkAddress(buf.String()) +} + +// newTestEntryStore returns a testEntryStore pre-populated with entries. +func newTestEntryStore() *testEntryStore { + store := &testEntryStore{ + entriesMap: make(map[tcpip.Address]NeighborEntry), + } + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + linkAddr := toLinkAddress(i) + + store.entriesMap[addr] = NeighborEntry{ + Addr: addr, + LocalAddr: testEntryLocalAddr, + LinkAddr: linkAddr, + } + } + return store +} + +// size returns the number of entries in the store. +func (s *testEntryStore) size() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entriesMap) +} + +// entry returns the entry at index i. Returns an empty entry and false if i is +// out of bounds. +func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { + return s.entryByAddr(toAddress(i)) +} + +// entryByAddr returns the entry matching addr for situations when the index is +// not available. Returns an empty entry and false if no entries match addr. +func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + entry, ok := s.entriesMap[addr] + return entry, ok +} + +// entries returns all entries in the store. +func (s *testEntryStore) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(s.entriesMap)) + s.mu.RLock() + defer s.mu.RUnlock() + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + if entry, ok := s.entriesMap[addr]; ok { + entries = append(entries, entry) + } + } + return entries +} + +// set modifies the link addresses of an entry. +func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { + addr := toAddress(i) + s.mu.Lock() + defer s.mu.Unlock() + if entry, ok := s.entriesMap[addr]; ok { + entry.LinkAddr = linkAddr + s.entriesMap[addr] = entry + } +} + +// testNeighborResolver implements LinkAddressResolver to emulate sending a +// neighbor probe. +type testNeighborResolver struct { + clock tcpip.Clock + neigh *neighborCache + entries *testEntryStore + delay time.Duration + onLinkAddressRequest func() +} + +var _ LinkAddressResolver = (*testNeighborResolver)(nil) + +func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(addr) + }) + + // Execute post address resolution action, if available. + if f := r.onLinkAddressRequest; f != nil { + f() + } + return nil +} + +// fakeRequest emulates handling a response for a link address request. +func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { + if entry, ok := r.entries.entryByAddr(addr); ok { + r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } +} + +func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == testEntryBroadcastAddr { + return testEntryBroadcastLinkAddr, true + } + return "", false +} + +func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 0 +} + +type entryEvent struct { + nicID tcpip.NICID + address tcpip.Address + linkAddr tcpip.LinkAddress + state NeighborState +} + +func TestNeighborCacheGetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheSetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + neigh.setConfig(c) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.Advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveEntry(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.Advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } +} + +type testContext struct { + clock *faketime.ManualClock + neigh *neighborCache + store *testEntryStore + linkRes *testNeighborResolver + nudDisp *testNUDDispatcher +} + +func newTestContext(c NUDConfigurations) testContext { + nudDisp := &testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + return testContext{ + clock: clock, + neigh: neigh, + store: store, + linkRes: linkRes, + nudDisp: nudDisp, + } +} + +type overflowOptions struct { + startAtEntryIndex int + wantStaticEntries []NeighborEntry +} + +func (c *testContext) overflowCache(opts overflowOptions) error { + // Fill the neighbor cache to capacity to verify the LRU eviction strategy is + // working properly after the entry removal. + for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + // Add a new entry + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.Advance(c.neigh.config().RetransmitTimer) + + var wantEvents []testEntryEventInfo + + // When beyond the full capacity, the cache will evict an entry as per the + // LRU eviction strategy. Note that the number of static entries should not + // affect the total number of dynamic entries that can be added. + if i >= neighborCacheSize+opts.startAtEntryIndex { + removedEntry, ok := c.store.entry(i - neighborCacheSize) + if !ok { + return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + } + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, testEntryEventInfo{ + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }) + + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the most recent entries. The order of entries reported + // by entries() is undeterministic, so entries have to be sorted before + // comparison. + wantUnsortedEntries := opts.wantStaticEntries + for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + return nil +} + +// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy +// respects the dynamic entry count. +func TestNeighborCacheOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when an entry is removed. +func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.Advance(c.neigh.config().RetransmitTimer) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the entry + c.neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that +// adding a duplicate static entry with the same link address does not dispatch +// any events. +func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that +// adding a duplicate static entry with a different link address dispatches a +// change event. +func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a duplicate entry with a different link address + staticLinkAddr += "duplicate" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when a static entry is +// added then removed. In this case, the dynamic entry count shouldn't have +// been touched. +func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.removeEntry(entry.Addr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU +// cache eviction strategy keeps count of the dynamic entry count when an entry +// is overwritten by a static entry. Static entries should not count towards +// the size of the LRU cache. +func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.Advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Override the entry with a static one using the same address + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: staticLinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheNotifiesWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + clock.Advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + id, ok := s.Fetch(false /* block */) + if !ok { + t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + if id != wakerID { + t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + + // Remove the waker before the neighbor cache has the opportunity to send a + // notification. + neigh.removeWaker(entry.Addr, &w) + clock.Advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Errorf("unexpected notification from waker with id %d", id) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if err != nil { + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheClear(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add a dynamic entry. + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a static entry. + neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Clear shoud remove both dynamic and static entries. + neigh.clear() + + // Remove events dispatched from clear() have no deterministic order so they + // need to be sorted beforehand. + wantUnsortedEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction +// strategy keeps count of the dynamic entry count when all entries are +// cleared. +func TestNeighborCacheClearThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.Advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Clear the cache. + c.neigh.clear() + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + frequentlyUsedEntry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + + // The following logic is very similar to overflowCache, but + // periodically refreshes the frequently used entry. + + // Fill the neighbor cache to capacity + for i := 0; i < neighborCacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Keep adding more entries + for i := neighborCacheSize; i < store.size(); i++ { + // Periodically refresh the frequently used entry + if i%(neighborCacheSize/2) == 0 { + _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + } + } + + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // An entry should have been removed, as per the LRU eviction strategy + removedEntry, ok := store.entry(i - neighborCacheSize + 1) + if !ok { + t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the frequently used entry and the most recent entries. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + wantUnsortedEntries := []NeighborEntry{ + { + Addr: frequentlyUsedEntry.Addr, + LocalAddr: frequentlyUsedEntry.LocalAddr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, + }, + } + + for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheConcurrent(t *testing.T) { + const concurrentProcesses = 16 + + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + storeEntries := store.entries() + for _, entry := range storeEntries { + var wg sync.WaitGroup + for r := 0; r < concurrentProcesses; r++ { + wg.Add(1) + go func(entry NeighborEntry) { + defer wg.Done() + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + } + }(entry) + } + + // Wait for all gorountines to send a request + wg.Wait() + + // Process all the requests for a single entry concurrently + clock.Advance(typicalLatency) + } + + // All goroutines add in the same order and add more values than can fit in + // the cache. Our eviction strategy requires that the last entries are + // present, up to the size of the neighbor cache, and the rest are missing. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + var wantUnsortedEntries []NeighborEntry + for i := store.size() - neighborCacheSize; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Errorf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheReplace(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add an entry + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // Verify the entry exists + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } + + // Notify of a link address change + var updatedLinkAddr tcpip.LinkAddress + { + entry, ok := store.entry(1) + if !ok { + t.Fatalf("store.entry(1) not found") + } + updatedLinkAddr = entry.LinkAddr + } + store.set(0, updatedLinkAddr) + neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + + // Requesting the entry again should start address resolution + { + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.Advance(config.DelayFirstProbeTime + typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + } + + // Verify the entry's new link address + { + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + clock.Advance(typicalLatency) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want = NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + } +} + +func TestNeighborCacheResolutionFailed(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + + var requestCount uint32 + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + onLinkAddressRequest: func() { + atomic.AddUint32(&requestCount, 1) + }, + } + + // First, sanity check that resolution is working + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + + // Verify that address resolution for an unknown address returns ErrNoLinkAddress + before := atomic.LoadUint32(&requestCount) + + entry.Addr += "2" + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + maxAttempts := neigh.config().MaxUnicastProbes + if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } +} + +// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes +// probes and not retrieving a confirmation before the duration defined by +// MaxMulticastProbes * RetransmitTimer. +func TestNeighborCacheResolutionTimeout(t *testing.T) { + config := DefaultNUDConfigurations() + config.RetransmitTimer = time.Millisecond // small enough to cause timeout + + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: time.Minute, // large enough to cause timeout + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } +} + +// TestNeighborCacheStaticResolution checks that static link addresses are +// resolved immediately and don't send resolution requests. +func TestNeighborCacheStaticResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + } + want := NeighborEntry{ + Addr: testEntryBroadcastAddr, + LocalAddr: testEntryLocalAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + } +} + +func BenchmarkCacheClear(b *testing.B) { + b.StopTimer() + config := DefaultNUDConfigurations() + clock := &tcpip.StdClock{} + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: 0, + } + + // Clear for every possible size of the cache + for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + // Fill the neighbor cache to capacity. + for i := 0; i < cacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + b.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh != nil { + <-doneCh + } + } + + b.StartTimer() + neigh.clear() + b.StopTimer() + } +} diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go new file mode 100644 index 000000000..9a72bec79 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -0,0 +1,490 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// NeighborEntry describes a neighboring device in the local network. +type NeighborEntry struct { + Addr tcpip.Address + LocalAddr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +// NeighborState defines the state of a NeighborEntry within the Neighbor +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +type NeighborState uint8 + +const ( + // Unknown means reachability has not been verified yet. This is the initial + // state of entries that have been created automatically by the Neighbor + // Unreachability Detection state machine. + Unknown NeighborState = iota + // Incomplete means that there is an outstanding request to resolve the + // address. + Incomplete + // Reachable means the path to the neighbor is functioning properly for both + // receive and transmit paths. + Reachable + // Stale means reachability to the neighbor is unknown, but packets are still + // able to be transmitted to the possibly stale link address. + Stale + // Delay means reachability to the neighbor is unknown and pending + // confirmation from an upper-level protocol like TCP, but packets are still + // able to be transmitted to the possibly stale link address. + Delay + // Probe means a reachability confirmation is actively being sought by + // periodically retransmitting reachability probes until a reachability + // confirmation is received, or until the max amount of probes has been sent. + Probe + // Static describes entries that have been explicitly added by the user. They + // do not expire and are not deleted until explicitly removed. + Static + // Failed means traffic should not be sent to this neighbor since attempts of + // reachability have returned inconclusive. + Failed +) + +// neighborEntry implements a neighbor entry's individual node behavior, as per +// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in +// parallel with the sending of packets to a neighbor, necessitating the +// entry's lock to be acquired for all operations. +type neighborEntry struct { + neighborEntryEntry + + nic *NIC + + // linkRes provides the functionality to send reachability probes, used in + // Neighbor Unreachability Detection. + linkRes LinkAddressResolver + + // nudState points to the Neighbor Unreachability Detection configuration. + nudState *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + neigh NeighborEntry + + // wakers is a set of waiters for address resolution result. Anytime state + // transitions out of incomplete these waiters are notified. It is nil iff + // address resolution is ongoing and no clients are waiting for the result. + wakers map[*sleep.Waker]struct{} + + // done is used to allow callers to wait on address resolution. It is nil + // iff nudState is not Reachable and address resolution is not yet in + // progress. + done chan struct{} + + isRouter bool + job *tcpip.Job +} + +// newNeighborEntry creates a neighbor cache entry starting at the default +// state, Unknown. Transition out of Unknown by calling either +// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created +// neighborEntry. +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { + return &neighborEntry{ + nic: nic, + linkRes: linkRes, + nudState: nudState, + neigh: NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + State: Unknown, + }, + } +} + +// newStaticNeighborEntry creates a neighbor cache entry starting at the Static +// state. The entry can only transition out of Static by directly calling +// `setStateLocked`. +func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + if nic.stack.nudDisp != nil { + nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + } + return &neighborEntry{ + nic: nic, + nudState: state, + neigh: NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + }, + } +} + +// addWaker adds w to the list of wakers waiting for address resolution. +// Assumes the entry has already been appropriately locked. +func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { + if w == nil { + return + } + if e.wakers == nil { + e.wakers = make(map[*sleep.Waker]struct{}) + } + e.wakers[w] = struct{}{} +} + +// notifyWakersLocked notifies those waiting for address resolution, whether it +// succeeded or failed. Assumes the entry has already been appropriately locked. +func (e *neighborEntry) notifyWakersLocked() { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has +// been added. +func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry +// has changed state or link-layer address. +func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry +// has been removed. +func (e *neighborEntry) dispatchRemoveEventLocked() { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + } +} + +// setStateLocked transitions the entry to the specified state immediately. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +// +// e.mu MUST be locked. +func (e *neighborEntry) setStateLocked(next NeighborState) { + // Cancel the previously scheduled action, if there is one. Entries in + // Unknown, Stale, or Static state do not have scheduled actions. + if timer := e.job; timer != nil { + timer.Cancel() + } + + prev := e.neigh.State + e.neigh.State = next + e.neigh.UpdatedAt = time.Now() + config := e.nudState.Config() + + switch next { + case Incomplete: + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() + + case Reachable: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + }) + e.job.Schedule(e.nudState.ReachableTime()) + + case Delay: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Probe) + e.setStateLocked(Probe) + }) + e.job.Schedule(config.DelayFirstProbeTime) + + case Probe: + var retryCounter uint32 + var sendUnicastProbe func() + + sendUnicastProbe = func() { + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendUnicastProbe() + + case Failed: + e.notifyWakersLocked() + e.job = e.nic.stack.newJob(&e.mu, func() { + e.nic.neigh.removeEntryLocked(e) + }) + e.job.Schedule(config.UnreachableTime) + + case Unknown, Stale, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next)) + } +} + +// handlePacketQueuedLocked advances the state machine according to a packet +// being queued for outgoing transmission. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +func (e *neighborEntry) handlePacketQueuedLocked() { + switch e.neigh.State { + case Unknown: + e.dispatchAddEventLocked(Incomplete) + e.setStateLocked(Incomplete) + + case Stale: + e.dispatchChangeEventLocked(Delay) + e.setStateLocked(Delay) + + case Incomplete, Reachable, Delay, Probe, Static, Failed: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or +// Neighbor Solicitation for ARP or NDP, respectively). +// +// Follows the logic defined in RFC 4861 section 7.2.3. +func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { + // Probes MUST be silently discarded if the target address is tentative, does + // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These + // checks MUST be done by the NetworkEndpoint. + + switch e.neigh.State { + case Unknown, Incomplete, Failed: + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchAddEventLocked(Stale) + e.setStateLocked(Stale) + e.notifyWakersLocked() + + case Reachable, Delay, Probe: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + + case Stale: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + } + + case Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleConfirmationLocked processes an incoming neighbor confirmation +// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively). +// +// Follows the state machine defined by RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce Cryptographically +// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the +// claimed source of an NDP message is the owner of the claimed address. +func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + switch e.neigh.State { + case Incomplete: + if len(linkAddr) == 0 { + // "If the link layer has addresses and no Target Link-Layer Address + // option is included, the receiving node SHOULD silently discard the + // received advertisement." - RFC 4861 section 7.2.5 + break + } + + e.neigh.LinkAddr = linkAddr + if flags.Solicited { + e.dispatchChangeEventLocked(Reachable) + e.setStateLocked(Reachable) + } else { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + e.isRouter = flags.IsRouter + e.notifyWakersLocked() + + // "Note that the Override flag is ignored if the entry is in the + // INCOMPLETE state." - RFC 4861 section 7.2.5 + + case Reachable, Stale, Delay, Probe: + sameLinkAddr := e.neigh.LinkAddr == linkAddr + + if !sameLinkAddr { + if !flags.Override { + if e.neigh.State == Reachable { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + break + } + + e.neigh.LinkAddr = linkAddr + + if !flags.Solicited { + if e.neigh.State != Stale { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } else { + // Notify the LinkAddr change, even though NUD state hasn't changed. + e.dispatchChangeEventLocked(e.neigh.State) + } + break + } + } + + if flags.Solicited && (flags.Override || sameLinkAddr) { + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + } + // Set state to Reachable again to refresh timers. + e.setStateLocked(Reachable) + e.notifyWakersLocked() + } + + if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { + // "In those cases where the IsRouter flag changes from TRUE to FALSE as + // a result of this update, the node MUST remove that router from the + // Default Router List and update the Destination Cache entries for all + // destinations using that neighbor as a router as specified in Section + // 7.3.3. This is needed to detect when a node that is used as a router + // stops forwarding packets due to being configured as a host." + // - RFC 4861 section 7.2.5 + // + // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 + // here. + ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + if !ok { + panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) + } + + if ndpEP, ok := ep.(NDPEndpoint); ok { + ndpEP.InvalidateDefaultRouter(e.neigh.Addr) + } + } + e.isRouter = flags.IsRouter + + case Unknown, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol +// (e.g. TCP acknowledgements) reachability confirmation. +func (e *neighborEntry) handleUpperLevelConfirmationLocked() { + switch e.neigh.State { + case Reachable, Stale, Delay, Probe: + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + // Set state to Reachable again to refresh timers. + } + e.setStateLocked(Reachable) + + case Unknown, Incomplete, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go new file mode 100644 index 000000000..a265fff0a --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -0,0 +1,2869 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "math" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + + entryTestNICID tcpip.NICID = 1 + entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") + entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") + + // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match the + // MTU of loopback interfaces on Linux systems. + entryTestNetDefaultMTU = 65536 +) + +// eventDiffOpts are the options passed to cmp.Diff to compare entry events. +// The UpdatedAt field is ignored due to a lack of a deterministic method to +// predict the time that an event will be dispatched. +func eventDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + } +} + +// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to +// sort slices of events for cases where ordering must be ignored. +func eventDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +// The following unit tests exercise every state transition and verify its +// behavior with RFC 4681. +// +// | From | To | Cause | Action | Event | +// | ========== | ========== | ========================================== | =============== | ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | +// | Unknown | Stale | Probe w/ unknown address | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | +// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | +// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | +// | Reachable | Stale | Reachable timer expired | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | +// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | +// | Stale | Delay | Packet sent | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | Changed | +// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | Changed | +// | Delay | Probe | Delay timer expired | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | Changed | +// | Probe | Probe | Retransmit timer expired | Send probe | Changed | +// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | | Unreachability timer expired | Delete entry | | + +type testEntryEventType uint8 + +const ( + entryTestAdded testEntryEventType = iota + entryTestChanged + entryTestRemoved +) + +func (t testEntryEventType) String() string { + switch t { + case entryTestAdded: + return "add" + case entryTestChanged: + return "change" + case entryTestRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Fields are exported for use with cmp.Diff. +type testEntryEventInfo struct { + EventType testEntryEventType + NICID tcpip.NICID + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +func (e testEntryEventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) +} + +// testNUDDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type testNUDDispatcher struct { + mu sync.Mutex + events []testEntryEventInfo +} + +var _ NUDDispatcher = (*testNUDDispatcher)(nil) + +func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { + d.mu.Lock() + defer d.mu.Unlock() + d.events = append(d.events, e) +} + +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestAdded, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestChanged, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +type entryTestLinkResolver struct { + mu sync.Mutex + probes []entryTestProbeInfo +} + +var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) + +type entryTestProbeInfo struct { + RemoteAddress tcpip.Address + RemoteLinkAddress tcpip.LinkAddress + LocalAddress tcpip.Address +} + +func (p entryTestProbeInfo) String() string { + return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) +} + +// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts +// to the local network if linkAddr is the zero value. +func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + p := entryTestProbeInfo{ + RemoteAddress: addr, + RemoteLinkAddress: linkAddr, + LocalAddress: localAddr, + } + r.mu.Lock() + defer r.mu.Unlock() + r.probes = append(r.probes, p) + return nil +} + +// ResolveStaticAddress attempts to resolve address without sending requests. +// It either resolves the name immediately or returns the empty LinkAddress. +func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + return "", false +} + +// LinkAddressProtocol returns the network protocol of the addresses this +// resolver can resolve. +func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return entryTestNetNumber +} + +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { + clock := faketime.NewManualClock() + disp := testNUDDispatcher{} + nic := NIC{ + id: entryTestNICID, + linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + stack: &Stack{ + clock: clock, + nudDisp: &disp, + }, + } + nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ + header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + nudState := NewNUDState(c, rng) + linkRes := entryTestLinkResolver{} + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + + // Stub out the neighbor cache to verify deletion from the cache. + nic.neigh = &neighborCache{ + nic: &nic, + state: nudState, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + nic.neigh.cache[entryTestAddr1] = entry + + return entry, &disp, &linkRes, clock +} + +// TestEntryInitiallyUnknown verifies that the state of a newly created +// neighborEntry is Unknown. +func TestEntryInitiallyUnknown(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.RetransmitTimer) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(time.Hour) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToIncomplete(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + { + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +func TestEntryUnknownToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + updatedAt := e.neigh.UpdatedAt + e.mu.Unlock() + + clock.Advance(c.RetransmitTimer) + + // UpdatedAt should remain the same during address resolution. + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.RetransmitTimer) + + // UpdatedAt should change after failing address resolution. Timing out after + // sending the last probe transitions the entry to Failed. + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + clock.Advance(c.RetransmitTimer) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + } + e.mu.Unlock() +} + +func TestEntryIncompleteToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryAddsAndClearsWakers verifies that wakers are added when +// addWakerLocked is called and cleared when address resolution finishes. In +// this case, address resolution will finish when transitioning from Incomplete +// to Reachable. +func TestEntryAddsAndClearsWakers(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got := e.wakers; got != nil { + t.Errorf("got e.wakers = %v, want = nil", got) + } + e.addWakerLocked(&w) + if got, want := w.IsAsserted(), false; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + if e.wakers == nil { + t.Error("expected e.wakers to be non-nil") + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.wakers != nil { + t.Errorf("got e.wakers = %v, want = nil", e.wakers) + } + if got, want := w.IsAsserted(), true; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + linkRes.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.Advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +type testLocker struct{} + +var _ sync.Locker = (*testLocker)(nil) + +func (*testLocker) Lock() {} +func (*testLocker) Unlock() {} + +func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.isRouter, false; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + if ipv6EP.invalidatedRtr != e.neigh.Addr { + t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryReachableToStaleWhenTimeout(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToDelay(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleUpperLevelConfirmationLocked() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToProbe(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario: +// 1. Probe is received +// 2. Entry is created in Stale +// 3. Packet is queued on the entry +// 4. Entry transitions to Delay then Probe +// 5. Probe is sent +func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Probe to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // Probe caused by the Delay-to-Probe transition + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.Advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryFailedGetsDeleted(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime + clock.Advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + // Verify the cache no longer contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { + t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) + } +} diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go new file mode 100644 index 000000000..aa7311ec6 --- /dev/null +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NeighborState"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Unknown-0] + _ = x[Incomplete-1] + _ = x[Reachable-2] + _ = x[Stale-3] + _ = x[Delay-4] + _ = x[Probe-5] + _ = x[Static-6] + _ = x[Failed-7] +} + +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" + +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} + +func (i NeighborState) String() string { + if i >= NeighborState(len(_NeighborState_index)-1) { + return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]] +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index afb7dfeaf..6cf54cc89 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,24 +16,18 @@ package stack import ( "fmt" + "math/rand" "reflect" - "sort" - "strings" "sync/atomic" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) -var ipv4BroadcastAddr = tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 8 * header.IPv4AddressSize, - }, -} +var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. @@ -45,20 +39,24 @@ type NIC struct { context NICContext stats NICStats + neigh *neighborCache + + // The network endpoints themselves may be modified by calling the interface's + // methods, but the map reference and entries must be constant. + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 mu struct { sync.RWMutex - enabled bool - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]uint32 + spoofing bool + promiscuous bool // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - ndp ndpState } } @@ -82,25 +80,6 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } -// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior. -type PrimaryEndpointBehavior int - -const ( - // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. - CanBePrimaryEndpoint PrimaryEndpointBehavior = iota - - // FirstPrimaryEndpoint indicates the endpoint should be the first - // primary endpoint considered. If there are multiple endpoints with - // this behavior, the most recently-added one will be first. - FirstPrimaryEndpoint - - // NeverPrimaryEndpoint indicates the endpoint should never be a - // primary endpoint. - NeverPrimaryEndpoint -) - // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -112,33 +91,43 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - context: ctx, - stats: makeNICStats(), + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) - nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) - nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), + + // Check for Neighbor Unreachability Detection support. + var nud NUDHandler + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { + rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) + nic.neigh = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + + // An interface value that holds a nil pointer but non-nil type is not the + // same as the nil interface. Because of this, nud must only be assignd if + // nic.neigh is non-nil since a nil reference to a neighborCache is not + // valid. + // + // See https://golang.org/doc/faq#nil_error for more information. + nud = nic.neigh } - nic.mu.ndp.initializeTempAddrState() - // Register supported packet endpoint protocols. + // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { nic.mu.packetEPs[netProto] = []PacketEndpoint{} } for _, netProto := range stack.networkProtocols { - nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} + netNum := netProto.Number() + nic.mu.packetEPs[netNum] = nil + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } nic.linkEP.Attach(nic) @@ -146,29 +135,32 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -// enabled returns true if n is enabled. -func (n *NIC) enabled() bool { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - return enabled +func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { + return n.networkEndpoints[proto] } -// disable disables n. +// Enabled implements NetworkInterface. +func (n *NIC) Enabled() bool { + return atomic.LoadUint32(&n.enabled) == 1 +} + +// setEnabled sets the enabled status for the NIC. // -// It undoes the work done by enable. -func (n *NIC) disable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if !enabled { - return nil +// Returns true if the enabled status was updated. +func (n *NIC) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&n.enabled, 1) == 0 } + return atomic.SwapUint32(&n.enabled, 0) == 1 +} +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() { n.mu.Lock() - err := n.disableLocked() + n.disableLocked() n.mu.Unlock() - return err } // disableLocked disables n. @@ -176,43 +168,19 @@ func (n *NIC) disable() *tcpip.Error { // It undoes the work done by enable. // // n MUST be locked. -func (n *NIC) disableLocked() *tcpip.Error { - if !n.mu.enabled { - return nil +func (n *NIC) disableLocked() { + if !n.setEnabled(false) { + return } - // TODO(b/147015577): Should Routes that are currently bound to n be + // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be // invalidated? Currently, Routes will continue to work when a NIC is enabled // again, and applications may not know that the underlying NIC was ever // disabled. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - n.mu.ndp.stopSolicitingRouters() - n.mu.ndp.cleanupState(false /* hostOnly */) - - // Stop DAD for all the unicast IPv6 endpoints that are in the - // permanentTentative state. - for _, r := range n.mu.endpoints { - if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } - } - - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - } - - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - // The address may have already been removed. - if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } + for _, ep := range n.networkEndpoints { + ep.Disable() } - - n.mu.enabled = false - return nil } // enable enables n. @@ -222,150 +190,38 @@ func (n *NIC) disableLocked() *tcpip.Error { // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if enabled { - return nil - } - n.mu.Lock() defer n.mu.Unlock() - if n.mu.enabled { - return nil - } - - n.mu.enabled = true - - // Create an endpoint to receive broadcast packets on this interface. - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } - } - - // Join the IPv6 All-Nodes Multicast group if the stack is configured to - // use IPv6. This is required to ensure that this node properly receives - // and responds to the various NDP messages that are destined to the - // all-nodes multicast address. An example is the Neighbor Advertisement - // when we perform Duplicate Address Detection, or Router Advertisement - // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 - // section 4.2 for more information. - // - // Also auto-generate an IPv6 link-local address based on the NIC's - // link address if it is configured to do so. Note, each interface is - // required to have IPv6 link-local unicast address, as per RFC 4291 - // section 2.1. - _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] - if !ok { + if !n.setEnabled(true) { return nil } - // Join the All-Nodes multicast group before starting DAD as responses to DAD - // (NDP NS) messages may be sent to the All-Nodes multicast group if the - // source address of the NDP NS is the unspecified address, as per RFC 4861 - // section 7.2.4. - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - - // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent - // state. - // - // Addresses may have aleady completed DAD but in the time since the NIC was - // last enabled, other devices may have acquired the same addresses. - for _, r := range n.mu.endpoints { - addr := r.ep.ID().LocalAddress - if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { - continue - } - - r.setKind(permanentTentative) - if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + for _, ep := range n.networkEndpoints { + if err := ep.Enable(); err != nil { return err } } - // Do not auto-generate an IPv6 link-local address for loopback devices. - if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { - // The valid and preferred lifetime is infinite for the auto-generated - // link-local address. - n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) - } - - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyways. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !n.stack.forwarding { - n.mu.ndp.startSolicitingRouters() - } - return nil } -// remove detaches NIC from the link endpoint, and marks existing referenced -// network endpoints expired. This guarantees no packets between this NIC and -// the network stack. +// remove detaches NIC from the link endpoint and releases network endpoint +// resources. This guarantees no packets between this NIC and the network +// stack. func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() n.disableLocked() - // TODO(b/151378115): come up with a better way to pick an error than the - // first one. - var err *tcpip.Error - - // Forcefully leave multicast groups. - for nid := range n.mu.mcastJoins { - if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { - err = tempErr - } - } - - // Remove permanent and permanentTentative addresses, so no packet goes out. - for nid, ref := range n.mu.endpoints { - switch ref.getKind() { - case permanentTentative, permanent: - if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { - err = tempErr - } - } + for _, ep := range n.networkEndpoints { + ep.Close() } // Detach from link endpoint, so no packet comes in. n.linkEP.Attach(nil) - - return err -} - -// becomeIPv6Router transitions n into an IPv6 router. -// -// When transitioning into an IPv6 router, host-only state (NDP discovered -// routers, discovered on-link prefixes, and auto-generated addresses) will -// be cleaned up/invalidated and NDP router solicitations will be stopped. -func (n *NIC) becomeIPv6Router() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.cleanupState(true /* hostOnly */) - n.mu.ndp.stopSolicitingRouters() -} - -// becomeIPv6Host transitions n into an IPv6 host. -// -// When transitioning into an IPv6 host, NDP router solicitations will be -// started. -func (n *NIC) becomeIPv6Host() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.startSolicitingRouters() + return nil } // setPromiscuousMode enables or disables promiscuous mode. @@ -382,7 +238,8 @@ func (n *NIC) isPromiscuousMode() bool { return rv } -func (n *NIC) isLoopback() bool { +// IsLoopback implements NetworkInterface. +func (n *NIC) IsLoopback() bool { return n.linkEP.Capabilities()&CapabilityLoopback != 0 } @@ -393,213 +250,53 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists for the given protocol and remoteAddr. If no non-deprecated -// endpoint exists, the first deprecated endpoint will be returned. -// -// If an IPv6 primary endpoint is requested, Source Address Selection (as -// defined by RFC 6724 section 5) will be performed. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { - if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { - return n.primaryIPv6Endpoint(remoteAddr) - } - +// primaryAddress returns an address that can be used to communicate with +// remoteAddr. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { n.mu.RLock() - defer n.mu.RUnlock() - - var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.mu.primary[protocol] { - if !r.isValidForOutgoingRLocked() { - continue - } - - if !r.deprecated { - if r.tryIncRef() { - // r is not deprecated, so return it immediately. - // - // If we kept track of a deprecated endpoint, decrement its reference - // count since it was incremented when we decided to keep track of it. - if deprecatedEndpoint != nil { - deprecatedEndpoint.decRefLocked() - deprecatedEndpoint = nil - } - - return r - } - } else if deprecatedEndpoint == nil && r.tryIncRef() { - // We prefer an endpoint that is not deprecated, but we keep track of r in - // case n doesn't have any non-deprecated endpoints. - // - // If we end up finding a more preferred endpoint, r's reference count - // will be decremented when such an endpoint is found. - deprecatedEndpoint = r - } - } - - // n doesn't have any valid non-deprecated endpoints, so return - // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated - // endpoints either). - return deprecatedEndpoint -} - -// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC -// 6724 section 5). -type ipv6AddrCandidate struct { - ref *referencedNetworkEndpoint - scope header.IPv6AddressScope -} - -// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - n.mu.RLock() - ref := n.primaryIPv6EndpointRLocked(remoteAddr) + spoofing := n.mu.spoofing n.mu.RUnlock() - return ref -} - -// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -// -// n.mu MUST be read locked. -func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] - - if len(primaryAddrs) == 0 { - return nil - } - - // Create a candidate set of available addresses we can potentially use as a - // source address. - cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) - for _, r := range primaryAddrs { - // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoingRLocked() { - continue - } - - addr := r.ep.ID().LocalAddress - scope, err := header.ScopeForIPv6Address(addr) - if err != nil { - // Should never happen as we got r from the primary IPv6 endpoint list and - // ScopeForIPv6Address only returns an error if addr is not an IPv6 - // address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) - } - - cs = append(cs, ipv6AddrCandidate{ - ref: r, - scope: scope, - }) - } - - remoteScope, err := header.ScopeForIPv6Address(remoteAddr) - if err != nil { - // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) - } - - // Sort the addresses as per RFC 6724 section 5 rules 1-3. - // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. - sort.Slice(cs, func(i, j int) bool { - sa := cs[i] - sb := cs[j] - - // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.ref.ep.ID().LocalAddress == remoteAddr { - return true - } - if sb.ref.ep.ID().LocalAddress == remoteAddr { - return false - } - - // Prefer appropriate scope as per RFC 6724 section 5 rule 2. - if sa.scope < sb.scope { - return sa.scope >= remoteScope - } else if sb.scope < sa.scope { - return sb.scope < remoteScope - } - - // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. - if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { - // If sa is not deprecated, it is preferred over sb. - return sbDep - } - - // Prefer temporary addresses as per RFC 6724 section 5 rule 7. - if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { - return saTemp - } - - // sa and sb are equal, return the endpoint that is closest to the front of - // the primary endpoint list. - return i < j - }) - - // Return the most preferred address that can have its reference count - // incremented. - for _, c := range cs { - if r := c.ref; r.tryIncRef() { - return r - } - } - - return nil -} - -// hasPermanentAddrLocked returns true if n has a permanent (including currently -// tentative) address, addr. -func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] + ep, ok := n.networkEndpoints[protocol] if !ok { - return false + return nil } - kind := ref.getKind() - - return kind == permanent || kind == permanentTentative + return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } -type getRefBehaviour int +type getAddressBehaviour int const ( // spoofing indicates that the NIC's spoofing flag should be observed when - // getting a NIC's referenced network endpoint. - spoofing getRefBehaviour = iota + // getting a NIC's address endpoint. + spoofing getAddressBehaviour = iota // promiscuous indicates that the NIC's promiscuous flag should be observed - // when getting a NIC's referenced network endpoint. + // when getting a NIC's address endpoint. promiscuous ) -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) +func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, spoofing) +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) } -// getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. +// getAddressEpOrCreateTemp returns the address endpoint for the given protocol +// and address. // // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { n.mu.RLock() - var spoofingOrPromiscuous bool switch tempRef { case spoofing: @@ -607,267 +304,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous } - - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // An endpoint with this id exists, check if it can be used and return it. - if !ref.isAssignedRLocked(spoofingOrPromiscuous) { - n.mu.RUnlock() - return nil - } - - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the address is found in the NIC's subnets. - createTempEP := spoofingOrPromiscuous - if !createTempEP { - for _, sn := range n.mu.addressRanges { - // Skip the subnet address. - if address == sn.ID() { - continue - } - // For now just skip the broadcast address, until we support it. - // FIXME(b/137608825): Add support for sending/receiving directed - // (subnet) broadcast. - if address == sn.Broadcast() { - continue - } - if sn.Contains(address) { - createTempEP = true - break - } - } - } - n.mu.RUnlock() - - if !createTempEP { - return nil - } - - // Try again with the lock in exclusive mode. If we still can't get the - // endpoint, create a new "temporary" endpoint. It will only exist while - // there's a route through it. - n.mu.Lock() - ref := n.getRefOrCreateTempLocked(protocol, address, peb) - n.mu.Unlock() - return ref + return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb) } -/// getRefOrCreateTempLocked returns an existing endpoint for address or creates -/// and returns a temporary endpoint. -func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // No need to check the type as we are ok with expired endpoints at this - // point. - if ref.tryIncRef() { - return ref - } - // tryIncRef failing means the endpoint is scheduled to be removed once the - // lock is released. Remove it here so we can create a new (temporary) one. - // The removal logic waiting for the lock handles this case. - n.removeEndpointLocked(ref) +// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean +// is passed to indicate whether or not we should generate temporary endpoints. +func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + if ep, ok := n.networkEndpoints[protocol]; ok { + return ep.AcquireAssignedAddress(address, createTemp, peb) } - // Add a new temporary endpoint. - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return nil - } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, temporary, static, false) - return ref + return nil } -// addAddressLocked adds a new protocolAddress to n. -// -// If n already has the address in a non-permanent state, and the kind given is -// permanent, that address will be promoted in place and its properties set to -// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP addresses before adding them. - - // Sanity check. - id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.mu.endpoints[id]; ok { - // Endpoint already exists. - if kind != permanent { - return nil, tcpip.ErrDuplicateAddress - } - switch ref.getKind() { - case permanentTentative, permanent: - // The NIC already have a permanent endpoint with that address. - return nil, tcpip.ErrDuplicateAddress - case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect the new peb, - // configType and deprecated status. - if ref.tryIncRef() { - // TODO(b/147748385): Perform Duplicate Address Detection when promoting - // an IPv6 endpoint to permanent. - ref.setKind(permanent) - ref.deprecated = deprecated - ref.configType = configType - - refs := n.mu.primary[ref.protocol] - for i, r := range refs { - if r == ref { - switch peb { - case CanBePrimaryEndpoint: - return ref, nil - case FirstPrimaryEndpoint: - if i == 0 { - return ref, nil - } - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - case NeverPrimaryEndpoint: - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - return ref, nil - } - } - } - - n.insertPrimaryEndpointLocked(ref, peb) - - return ref, nil - } - // tryIncRef failing means the endpoint is scheduled to be removed once - // the lock is released. Remove it here so we can create a new - // (permanent) one. The removal logic waiting for the lock handles this - // case. - n.removeEndpointLocked(ref) - } - } - - netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] +// addAddress adds a new address to n, so that it starts accepting packets +// targeted at the given address (and network protocol). +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { + ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return nil, tcpip.ErrUnknownProtocol - } - - // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack) - if err != nil { - return nil, err - } - - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) - - // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process if the NIC is - // enabled. If the NIC is not enabled, DAD will be started when the NIC is - // enabled. - if isIPv6Unicast && kind == permanent { - kind = permanentTentative - } - - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, - configType: configType, - deprecated: deprecated, - } - - // Set up cache if link address resolution exists for this protocol. - if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { - ref.linkCache = n.stack - } - } - - // If we are adding an IPv6 unicast address, join the solicited-node - // multicast address. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) - if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { - return nil, err - } + return tcpip.ErrUnknownProtocol } - n.mu.endpoints[id] = ref - - n.insertPrimaryEndpointLocked(ref, peb) - - // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. - if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { - if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { - return nil, err - } + addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + if err == nil { + // We have no need for the address endpoint. + addressEndpoint.DecRef() } - - return ref, nil -} - -// AddAddress adds a new address to n, so that it starts accepting packets -// targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { - // Add the endpoint. - n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) - n.mu.Unlock() - return err } -// AllAddresses returns all addresses (primary and non-primary) associated with +// allPermanentAddresses returns all permanent addresses associated with // this NIC. -func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - - addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) - for nid, ref := range n.mu.endpoints { - // Don't include tentative, expired or temporary endpoints to - // avoid confusion and prevent the caller from using those. - switch ref.getKind() { - case permanentExpired, temporary: - continue +func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { + var addrs []tcpip.ProtocolAddress + for p, ep := range n.networkEndpoints { + for _, a := range ep.PermanentAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: nid.LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, - }) } return addrs } -// PrimaryAddresses returns the primary addresses associated with this NIC. -func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - +// primaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress - for proto, list := range n.mu.primary { - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints - // to avoid confusion and prevent the caller from using - // those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, - }) + for p, ep := range n.networkEndpoints { + for _, a := range ep.PrimaryAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } return addrs @@ -879,289 +363,135 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { - n.mu.RLock() - defer n.mu.RUnlock() - - list, ok := n.mu.primary[proto] + ep, ok := n.networkEndpoints[proto] if !ok { return tcpip.AddressWithPrefix{} } - var deprecatedEndpoint *referencedNetworkEndpoint - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints to avoid confusion - // and prevent the caller from using those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - if !ref.deprecated { - return tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - } - } - - if deprecatedEndpoint == nil { - deprecatedEndpoint = ref - } - } - - if deprecatedEndpoint != nil { - return tcpip.AddressWithPrefix{ - Address: deprecatedEndpoint.ep.ID().LocalAddress, - PrefixLen: deprecatedEndpoint.ep.PrefixLen(), - } - } - - return tcpip.AddressWithPrefix{} + return ep.MainAddress() } -// AddAddressRange adds a range of addresses to n, so that it starts accepting -// packets targeted at the given addresses and network protocol. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { - n.mu.Lock() - n.mu.addressRanges = append(n.mu.addressRanges, subnet) - n.mu.Unlock() -} - -// RemoveAddressRange removes the given address range from n. -func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { - n.mu.Lock() - - // Use the same underlying array. - tmp := n.mu.addressRanges[:0] - for _, sub := range n.mu.addressRanges { - if sub != subnet { - tmp = append(tmp, sub) +// removeAddress removes an address from n. +func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { + for _, ep := range n.networkEndpoints { + if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + continue + } else { + return err } } - n.mu.addressRanges = tmp - n.mu.Unlock() + return tcpip.ErrBadLocalAddress } -// AddressRanges returns the Subnets associated with this NIC. -func (n *NIC) AddressRanges() []tcpip.Subnet { - n.mu.RLock() - defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints)) - for nid := range n.mu.endpoints { - sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) - if err != nil { - // This should never happen as the mask has been carefully crafted to - // match the address. - panic("Invalid endpoint subnet: " + err.Error()) - } - sns = append(sns, sn) +func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { + if n.neigh == nil { + return nil, tcpip.ErrNotSupported } - return append(sns, n.mu.addressRanges...) -} -// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required -// by peb. -// -// n MUST be locked. -func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { - switch peb { - case CanBePrimaryEndpoint: - n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) - case FirstPrimaryEndpoint: - n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) - } + return n.neigh.entries(), nil } -func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := *r.ep.ID() - - // Nothing to do if the reference has already been replaced with a different - // one. This happens in the case where 1) this endpoint's ref count hit zero - // and was waiting (on the lock) to be removed and 2) the same address was - // re-added in the meantime by removing this endpoint from the list and - // adding a new one. - if n.mu.endpoints[id] != r { +func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { + if n.neigh == nil { return } - if r.getKind() == permanent { - panic("Reference count dropped to zero before being removed") - } + n.neigh.removeWaker(addr, w) +} - delete(n.mu.endpoints, id) - refs := n.mu.primary[r.protocol] - for i, ref := range refs { - if ref == r { - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - refs[len(refs)-1] = nil - break - } +func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } - r.ep.Close() -} - -func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { - n.mu.Lock() - n.removeEndpointLocked(r) - n.mu.Unlock() + n.neigh.addStaticEntry(addr, linkAddress) + return nil } -func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadLocalAddress - } - - kind := r.getKind() - if kind != permanent && kind != permanentTentative { - return tcpip.ErrBadLocalAddress +func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } - switch r.protocol { - case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) - default: - r.expireLocked() - return nil + if !n.neigh.removeEntry(addr) { + return tcpip.ErrBadAddress } + return nil } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { - addr := r.addrWithPrefix() - - isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) - - if isIPv6Unicast { - n.mu.ndp.stopDuplicateAddressDetection(addr.Address) - - // If we are removing an address generated via SLAAC, cleanup - // its SLAAC resources and notify the integrator. - switch r.configType { - case slaac: - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - case slaacTemp: - n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - } - } - - r.expireLocked() - - // At this point the endpoint is deleted. - - // If we are removing an IPv6 unicast address, leave the solicited-node - // multicast address. - // - // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node - // multicast group may be left by user action. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr.Address) - if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } +func (n *NIC) clearNeighbors() *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } + n.neigh.clear() return nil } -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - return n.removePermanentAddressLocked(addr) -} - // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.joinGroupLocked(protocol, addr) -} - -// joinGroupLocked adds a new endpoint for the given multicast address, if none -// exists yet. Otherwise it just increments its count. n MUST be locked before -// joinGroupLocked is called. -func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as // outlined in RFC 3810 section 5. - id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - if joins == 0 { - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return tcpip.ErrUnknownProtocol - } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported } - n.mu.mcastJoins[id] = joins + 1 - return nil + + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + _, err := gep.JoinGroup(addr) + return err } // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.leaveGroupLocked(addr, false /* force */) -} +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported + } -// leaveGroupLocked decrements the count for the given multicast address, and -// when it reaches zero removes the endpoint for this address. n MUST be locked -// before leaveGroupLocked is called. -// -// If force is true, then the count for the multicast addres is ignored and the -// endpoint will be removed immediately. -func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { - id := NetworkEndpointID{addr} - joins, ok := n.mu.mcastJoins[id] + gep, ok := ep.(GroupAddressableEndpoint) if !ok { - // There are no joins with this address on this NIC. - return tcpip.ErrBadLocalAddress + return tcpip.ErrNotSupported } - joins-- - if force || joins == 0 { - // There are no outstanding joins or we are forced to leave, clean up. - delete(n.mu.mcastJoins, id) - return n.removePermanentAddressLocked(addr) + if _, err := gep.LeaveGroup(addr); err != nil { + return err } - n.mu.mcastJoins[id] = joins return nil } // isInGroup returns true if n has joined the multicast group addr. func (n *NIC) isInGroup(addr tcpip.Address) bool { - n.mu.RLock() - joins := n.mu.mcastJoins[NetworkEndpointID{addr}] - n.mu.RUnlock() + for _, ep := range n.networkEndpoints { + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + continue + } - return joins != 0 + if gep.IsInGroup(addr) { + return true + } + } + + return false } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) +func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { + r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + defer r.Release() r.RemoteLinkAddress = remotelinkAddr - - ref.ep.HandlePacket(&r, pkt) - ref.decRef() + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -1172,7 +502,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // the ownership of the items is not retained by the caller. func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() - enabled := n.mu.enabled + enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { n.mu.RUnlock() @@ -1198,17 +528,15 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp local = n.linkEP.LinkAddress() } - // Are any packet sockets listening for this network protocol? + // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet type sockets that may be listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -1223,37 +551,42 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } if hasTransportHdr { + pkt.TransportProtocolNumber = transProtoNum // Parse the transport header if present. if state, ok := n.stack.transportProtocols[transProtoNum]; ok { state.proto.Parse(pkt) } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader) + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return + if n.stack.handleLocal && !n.IsLoopback() { + if r := n.getAddress(protocol, src); r != nil { + r.DecRef() + + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } } - // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if !n.IsLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. + n.stack.stats.IP.IPTablesPreroutingDropped.Increment() return } } - if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) + if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { + n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) return } @@ -1261,7 +594,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1269,25 +602,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.ref.nic - n.mu.RLock() - ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() - n.mu.RUnlock() - if ok { - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote - r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, pkt) - ref.decRef() - r.Release() - return + n := r.nic + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { + if n.isValidForOutgoing(addressEndpoint) { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote + r.RemoteAddress = src + // TODO(b/123449044): Update the source NIC as well. + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + addressEndpoint.DecRef() + r.Release() + return + } + + addressEndpoint.DecRef() } // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) @@ -1311,26 +645,39 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this - // by having lower layers explicity write each header instead of just - // pkt.Header. - - // pkt may have set its NetworkHeader and TransportHeader. If we're - // forwarding, we'll have to copy them into pkt.Header. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) - if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) - } - if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) - } - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } @@ -1341,11 +688,11 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return + return TransportPacketProtocolUnreachable } transProto := state.proto @@ -1355,52 +702,58 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. - if pkt.TransportHeader == nil { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + if pkt.TransportHeader().View().IsEmpty() { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() - return + // We consider a malformed transport packet handled because there is + // nothing the caller can do. + return TransportPacketHandled } - pkt.TransportHeader = transHeader - } else { - // This is either a bad packet or was re-assembled from fragments. - transProto.Parse(pkt) + } else if !transProto.Parse(pkt) { + n.stack.stats.MalformedRcvdPackets.Increment() + return TransportPacketHandled } } - if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() - return + return TransportPacketHandled } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} if n.stack.demux.deliverPacket(r, protocol, pkt, id) { - return + return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { if state.defaultHandler(r, id, pkt) { - return + return TransportPacketHandled } } - // We could not find an appropriate destination for this packet, so - // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { + // We could not find an appropriate destination for this packet so + // give the protocol specific error handler a chance to handle it. + // If it doesn't handle it then we should do so. + switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() + return TransportPacketHandled + case UnknownDestinationPacketUnhandled: + return TransportPacketDestinationPortUnreachable + case UnknownDestinationPacketHandled: + return TransportPacketHandled + default: + panic(fmt.Sprintf("unrecognized result from HandleUnknownDestinationPacket = %d", res)) } } @@ -1433,137 +786,42 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } } -// ID returns the identifier of n. +// ID implements NetworkInterface. func (n *NIC) ID() tcpip.NICID { return n.id } -// Name returns the name of n. +// Name implements NetworkInterface. func (n *NIC) Name() string { return n.name } -// Stack returns the instance of the Stack that owns this NIC. -func (n *NIC) Stack() *Stack { - return n.stack -} - -// LinkEndpoint returns the link endpoint of n. +// LinkEndpoint implements NetworkInterface. func (n *NIC) LinkEndpoint() LinkEndpoint { return n.linkEP } -// isAddrTentative returns true if addr is tentative on n. -// -// Note that if addr is not associated with n, then this function will return -// false. It will only return true if the address is associated with the NIC -// AND it is tentative. -func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - n.mu.RLock() - defer n.mu.RUnlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return false +// nudConfigs gets the NUD configurations for n. +func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { + if n.neigh == nil { + return NUDConfigurations{}, tcpip.ErrNotSupported } - - return ref.getKind() == permanentTentative + return n.neigh.config(), nil } -// dupTentativeAddrDetected attempts to inform n that a tentative addr is a -// duplicate on a link. +// setNUDConfigs sets the NUD configurations for n. // -// dupTentativeAddrDetected will remove the tentative address if it exists. If -// the address was generated via SLAAC, an attempt will be made to generate a -// new address. -func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadAddress - } - - if ref.getKind() != permanentTentative { - return tcpip.ErrInvalidEndpointState - } - - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a - // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { - return err - } - - prefix := ref.addrWithPrefix().Subnet() - - switch ref.configType { - case slaac: - n.mu.ndp.regenerateSLAACAddr(prefix) - case slaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } - + c.resetInvalidFields() + n.neigh.setConfig(c) return nil } -// setNDPConfigs sets the NDP configurations for n. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (n *NIC) setNDPConfigs(c NDPConfigurations) { - c.validate() - - n.mu.Lock() - n.mu.ndp.configs = c - n.mu.Unlock() -} - -// handleNDPRA handles an NDP Router Advertisement message that arrived on n. -func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.handleRA(ip, ra) -} - -type networkEndpointKind int32 - -const ( - // A permanentTentative endpoint is a permanent address that is not yet - // considered to be fully bound to an interface in the traditional - // sense. That is, the address is associated with a NIC, but packets - // destined to the address MUST NOT be accepted and MUST be silently - // dropped, and the address MUST NOT be used as a source address for - // outgoing packets. For IPv6, addresses will be of this kind until - // NDP's Duplicate Address Detection has resolved, or be deleted if - // the process results in detecting a duplicate address. - permanentTentative networkEndpointKind = iota - - // A permanent endpoint is created by adding a permanent address (vs. a - // temporary one) to the NIC. Its reference count is biased by 1 to avoid - // removal when no route holds a reference to it. It is removed by explicitly - // removing the permanent address from the NIC. - permanent - - // An expired permanent endpoint is a permanent endpoint that had its address - // removed from the NIC, and it is waiting to be removed once no more routes - // hold a reference to it. This is achieved by decreasing its reference count - // by 1. If its address is re-added before the endpoint is removed, its type - // changes back to permanent and its reference count increases by 1 again. - permanentExpired - - // A temporary endpoint is created for spoofing outgoing packets, or when in - // promiscuous mode and accepting incoming packets that don't match any - // permanent endpoint. Its reference count is not biased by 1 and the - // endpoint is removed immediately when no more route holds a reference to - // it. A temporary endpoint can be promoted to permanent if its address - // is added permanently. - temporary -) - func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1594,147 +852,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } -type networkEndpointConfigType int32 - -const ( - // A statically configured endpoint is an address that was added by - // some user-specified action (adding an explicit address, joining a - // multicast group). - static networkEndpointConfigType = iota - - // A SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4862 section 5.5.3. - slaac - - // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are - // not expected to be valid (or preferred) forever; hence the term temporary. - slaacTemp -) - -type referencedNetworkEndpoint struct { - ep NetworkEndpoint - nic *NIC - protocol tcpip.NetworkProtocolNumber - - // linkCache is set if link address resolution is enabled for this - // protocol. Set to nil otherwise. - linkCache LinkAddressCache - - // refs is counting references held for this endpoint. When refs hits zero it - // triggers the automatic removal of the endpoint from the NIC. - refs int32 - - // networkEndpointKind must only be accessed using {get,set}Kind(). - kind networkEndpointKind - - // configType is the method that was used to configure this endpoint. - // This must never change except during endpoint creation and promotion to - // permanent. - configType networkEndpointConfigType - - // deprecated indicates whether or not the endpoint should be considered - // deprecated. That is, when deprecated is true, other endpoints that are not - // deprecated should be preferred. - deprecated bool -} - -func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { - return tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } -} - -func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { - return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) -} - -func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { - atomic.StoreInt32((*int32)(&r.kind), int32(kind)) -} - // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address) // has been removed) unless the NIC is in spoofing mode, or temporary. -func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - r.nic.mu.RLock() - defer r.nic.mu.RUnlock() - - return r.isValidForOutgoingRLocked() -} - -// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires -// r.nic.mu to be read locked. -func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - if !r.nic.mu.enabled { - return false - } - - return r.isAssignedRLocked(r.nic.mu.spoofing) -} - -// isAssignedRLocked returns true if r is considered to be assigned to the NIC. -// -// r.nic.mu must be read locked. -func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { - switch r.getKind() { - case permanentTentative: - return false - case permanentExpired: - return spoofingOrPromiscuous - default: - return true - } -} - -// expireLocked decrements the reference count and marks the permanent endpoint -// as expired. -func (r *referencedNetworkEndpoint) expireLocked() { - r.setKind(permanentExpired) - r.decRefLocked() -} - -// decRef decrements the ref count and cleans up the endpoint once it reaches -// zero. -func (r *referencedNetworkEndpoint) decRef() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpoint(r) - } -} - -// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpointLocked(r) - } -} - -// incRef increments the ref count. It must only be called when the caller is -// known to be holding a reference to the endpoint, otherwise tryIncRef should -// be used. -func (r *referencedNetworkEndpoint) incRef() { - atomic.AddInt32(&r.refs, 1) -} - -// tryIncRef attempts to increment the ref count from n to n+1, but only if n is -// not zero. That is, it will increment the count if the endpoint is still -// alive, and do nothing if it has already been clean up. -func (r *referencedNetworkEndpoint) tryIncRef() bool { - for { - v := atomic.LoadInt32(&r.refs) - if v == 0 { - return false - } - - if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { - return true - } - } -} - -// stack returns the Stack instance that owns the underlying endpoint. -func (r *referencedNetworkEndpoint) stack() *Stack { - return r.nic.stack +func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + return n.Enabled() && ep.IsAssigned(spoofing) } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 31f865260..fdd49b77f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,88 +15,40 @@ package stack import ( - "math" "testing" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) -var _ LinkEndpoint = (*testLinkEndpoint)(nil) +var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) +var _ NDPEndpoint = (*testIPv6Endpoint)(nil) -// A LinkEndpoint that throws away outgoing packets. +// An IPv6 NetworkEndpoint that throws away outgoing packets. // -// We use this instead of the channel endpoint as the channel package depends on +// We use this instead of ipv6.endpoint because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. -type testLinkEndpoint struct { - dispatcher NetworkDispatcher -} - -// Attach implements LinkEndpoint.Attach. -func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements LinkEndpoint.IsAttached. -func (e *testLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements LinkEndpoint.MTU. -func (*testLinkEndpoint) MTU() uint32 { - return math.MaxUint16 -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { - return CapabilityResolutionRequired -} +type testIPv6Endpoint struct { + AddressableEndpointState -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*testLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} + nicID tcpip.NICID + linkEP LinkEndpoint + protocol *testIPv6Protocol -// LinkAddress returns the link address of this endpoint. -func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" + invalidatedRtr tcpip.Address } -// Wait implements LinkEndpoint.Wait. -func (*testLinkEndpoint) Wait() {} - -// WritePacket implements LinkEndpoint.WritePacket. -func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) Enable() *tcpip.Error { return nil } -// WritePackets implements LinkEndpoint.WritePackets. -func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported +func (*testIPv6Endpoint) Enabled() bool { + return true } -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - nicID tcpip.NICID - id NetworkEndpointID - prefixLen int - linkEP LinkEndpoint - protocol *testIPv6Protocol -} +func (*testIPv6Endpoint) Disable() {} // DefaultTTL implements NetworkEndpoint.DefaultTTL. func (*testIPv6Endpoint) DefaultTTL() uint8 { @@ -108,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 { return e.linkEP.MTU() - header.IPv6MinimumSize } -// Capabilities implements NetworkEndpoint.Capabilities. -func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize @@ -136,33 +83,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip return tcpip.ErrNotSupported } -// ID implements NetworkEndpoint.ID. -func (e *testIPv6Endpoint) ID() *NetworkEndpointID { - return &e.id -} - -// PrefixLen implements NetworkEndpoint.PrefixLen. -func (e *testIPv6Endpoint) PrefixLen() int { - return e.prefixLen -} - -// NICID implements NetworkEndpoint.NICID. -func (e *testIPv6Endpoint) NICID() tcpip.NICID { - return e.nicID -} - // HandlePacket implements NetworkEndpoint.HandlePacket. func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { } // Close implements NetworkEndpoint.Close. -func (*testIPv6Endpoint) Close() {} +func (e *testIPv6Endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} // NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } +func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.invalidatedRtr = rtr +} + var _ NetworkProtocol = (*testIPv6Protocol)(nil) // An IPv6 NetworkProtocol that supports the bare minimum to make a stack @@ -194,23 +132,23 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { - return &testIPv6Endpoint{ - nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - protocol: p, - }, nil +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { + e := &testIPv6Endpoint{ + nicID: nic.ID(), + linkEP: nic.LinkEndpoint(), + protocol: p, + } + e.AddressableEndpointState.Init(e) + return e } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return nil } @@ -233,7 +171,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { return nil } @@ -245,38 +183,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd return "", false } -// Test the race condition where a NIC is removed and an RS timer fires at the -// same time. -func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { - const ( - nicID = 1 - - maxRtrSolicitations = 5 - ) - - e := testLinkEndpoint{} - s := New(Options{ - NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, - NDPConfigs: NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: minimumRtrSolicitationInterval, - }, - }) - - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) - } - - s.mu.Lock() - // Wait for the router solicitation timer to fire and block trying to obtain - // the stack lock when doing link address resolution. - time.Sleep(minimumRtrSolicitationInterval * 2) - if err := s.removeNICLocked(nicID); err != nil { - t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) - } - s.mu.Unlock() -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. @@ -301,7 +207,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go new file mode 100644 index 000000000..e1ec15487 --- /dev/null +++ b/pkg/tcpip/stack/nud.go @@ -0,0 +1,466 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "math" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // defaultBaseReachableTime is the default base duration for computing the + // random reachable time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of a uniformly distributed random value between the minimum and + // maximum random factors, multiplied by the base reachable time. Using a + // random component eliminates the possibility that Neighbor Unreachability + // Detection messages will synchronize with each other. + // + // Default taken from REACHABLE_TIME of RFC 4861 section 10. + defaultBaseReachableTime = 30 * time.Second + + // minimumBaseReachableTime is the minimum base duration for computing the + // random reachable time. + // + // Minimum = 1ms + minimumBaseReachableTime = time.Millisecond + + // defaultMinRandomFactor is the default minimum value of the random factor + // used for computing reachable time. + // + // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10. + defaultMinRandomFactor = 0.5 + + // defaultMaxRandomFactor is the default maximum value of the random factor + // used for computing reachable time. + // + // The default value depends on the value of MinRandomFactor. + // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10, + // the value from the RFC will be used; otherwise, the default is + // MinRandomFactor multiplied by three. + defaultMaxRandomFactor = 1.5 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + + // defaultDelayFirstProbeTime is the default duration to wait for a + // non-Neighbor-Discovery related protocol to reconfirm reachability after + // entering the DELAY state. After this time, a reachability probe will be + // sent and the entry will transition to the PROBE state. + // + // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10. + defaultDelayFirstProbeTime = 5 * time.Second + + // defaultMaxMulticastProbes is the default number of reachabililty probes + // to send before concluding negative reachability and deleting the neighbor + // entry from the INCOMPLETE state. + // + // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10. + defaultMaxMulticastProbes = 3 + + // defaultMaxUnicastProbes is the default number of reachability probes to + // send before concluding retransmission from within the PROBE state should + // cease and the entry SHOULD be deleted. + // + // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10. + defaultMaxUnicastProbes = 3 + + // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD + // delay sending a response for a random time between 0 and this time, if the + // target address is an anycast address. + // + // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10. + defaultMaxAnycastDelayTime = time.Second + + // defaultMaxReachbilityConfirmations is the default amount of unsolicited + // reachability confirmation messages a node MAY send to all-node multicast + // address when it determines its link-layer address has changed. + // + // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. + defaultMaxReachbilityConfirmations = 3 + + // defaultUnreachableTime is the default duration for how long an entry will + // remain in the FAILED state before being removed from the neighbor cache. + // + // Note, there is no equivalent protocol constant defined in RFC 4861. It + // leaves the specifics of any garbage collection mechanism up to the + // implementation. + defaultUnreachableTime = 5 * time.Second +) + +// NUDDispatcher is the interface integrators of netstack must implement to +// receive and handle NUD related events. +type NUDDispatcher interface { + // OnNeighborAdded will be called when a new entry is added to a NIC's (with + // ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) + // neighbor table changes state and/or link address. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborRemoved will be called when an entry is removed from a NIC's + // (with ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) +} + +// ReachabilityConfirmationFlags describes the flags used within a reachability +// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, +// respectively). +type ReachabilityConfirmationFlags struct { + // Solicited indicates that the advertisement was sent in response to a + // reachability probe. + Solicited bool + + // Override indicates that the reachability confirmation should override an + // existing neighbor cache entry and update the cached link-layer address. + // When Override is not set the confirmation will not update a cached + // link-layer address, but will update an existing neighbor cache entry for + // which no link-layer address is known. + Override bool + + // IsRouter indicates that the sender is a router. + IsRouter bool +} + +// NUDHandler communicates external events to the Neighbor Unreachability +// Detection state machine, which is implemented per-interface. This is used by +// network endpoints to inform the Neighbor Cache of probes and confirmations. +type NUDHandler interface { + // HandleProbe processes an incoming neighbor probe (e.g. ARP request or + // Neighbor Solicitation for ARP or NDP, respectively). Validation of the + // probe needs to be performed before calling this function since the + // Neighbor Cache doesn't have access to view the NIC's assigned addresses. + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + + // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP + // reply or Neighbor Advertisement for ARP or NDP, respectively). + HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) + + // HandleUpperLevelConfirmation processes an incoming upper-level protocol + // (e.g. TCP acknowledgements) reachability confirmation. + HandleUpperLevelConfirmation(addr tcpip.Address) +} + +// NUDConfigurations is the NUD configurations for the netstack. This is used +// by the neighbor cache to operate the NUD state machine on each device in the +// local network. +type NUDConfigurations struct { + // BaseReachableTime is the base duration for computing the random reachable + // time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of uniformly distributed random value between minRandomFactor and + // maxRandomFactor multiplied by baseReachableTime. Using a random component + // eliminates the possibility that Neighbor Unreachability Detection messages + // will synchronize with each other. + // + // After this time, a neighbor entry will transition from REACHABLE to STALE + // state. + // + // Must be greater than 0. + BaseReachableTime time.Duration + + // LearnBaseReachableTime enables learning BaseReachableTime during runtime + // from the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option. + LearnBaseReachableTime bool + + // MinRandomFactor is the minimum value of the random factor used for + // computing reachable time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be greater than 0. + MinRandomFactor float32 + + // MaxRandomFactor is the maximum value of the random factor used for + // computing reachabile time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be great than or equal to MinRandomFactor. + MaxRandomFactor float32 + + // RetransmitTimer is the duration between retransmission of reachability + // probes in the PROBE state. + RetransmitTimer time.Duration + + // LearnRetransmitTimer enables learning RetransmitTimer during runtime from + // the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option. + LearnRetransmitTimer bool + + // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery + // related protocol to reconfirm reachability after entering the DELAY state. + // After this time, a reachability probe will be sent and the entry will + // transition to the PROBE state. + // + // Must be greater than 0. + DelayFirstProbeTime time.Duration + + // MaxMulticastProbes is the number of reachability probes to send before + // concluding negative reachability and deleting the neighbor entry from the + // INCOMPLETE state. + // + // Must be greater than 0. + MaxMulticastProbes uint32 + + // MaxUnicastProbes is the number of reachability probes to send before + // concluding retransmission from within the PROBE state should cease and + // entry SHOULD be deleted. + // + // Must be greater than 0. + MaxUnicastProbes uint32 + + // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a + // response for a random time between 0 and this time, if the target address + // is an anycast address. + // + // TODO(gvisor.dev/issue/2242): Use this option when sending solicited + // neighbor confirmations to anycast addresses and proxying neighbor + // confirmations. + MaxAnycastDelayTime time.Duration + + // MaxReachabilityConfirmations is the number of unsolicited reachability + // confirmation messages a node MAY send to all-node multicast address when + // it determines its link-layer address has changed. + // + // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD + // configuration option is necessary. + MaxReachabilityConfirmations uint32 + + // UnreachableTime describes how long an entry will remain in the FAILED + // state before being removed from the neighbor cache. + UnreachableTime time.Duration +} + +// DefaultNUDConfigurations returns a NUDConfigurations populated with default +// values defined by RFC 4861 section 10. +func DefaultNUDConfigurations() NUDConfigurations { + return NUDConfigurations{ + BaseReachableTime: defaultBaseReachableTime, + LearnBaseReachableTime: true, + MinRandomFactor: defaultMinRandomFactor, + MaxRandomFactor: defaultMaxRandomFactor, + RetransmitTimer: defaultRetransmitTimer, + LearnRetransmitTimer: true, + DelayFirstProbeTime: defaultDelayFirstProbeTime, + MaxMulticastProbes: defaultMaxMulticastProbes, + MaxUnicastProbes: defaultMaxUnicastProbes, + MaxAnycastDelayTime: defaultMaxAnycastDelayTime, + MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, + UnreachableTime: defaultUnreachableTime, + } +} + +// resetInvalidFields modifies an invalid NDPConfigurations with valid values. +// If invalid values are present in c, the corresponding default values will be +// used instead. This is needed to check, and conditionally fix, user-specified +// NUDConfigurations. +func (c *NUDConfigurations) resetInvalidFields() { + if c.BaseReachableTime < minimumBaseReachableTime { + c.BaseReachableTime = defaultBaseReachableTime + } + if c.MinRandomFactor <= 0 { + c.MinRandomFactor = defaultMinRandomFactor + } + if c.MaxRandomFactor < c.MinRandomFactor { + c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor) + } + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } + if c.DelayFirstProbeTime == 0 { + c.DelayFirstProbeTime = defaultDelayFirstProbeTime + } + if c.MaxMulticastProbes == 0 { + c.MaxMulticastProbes = defaultMaxMulticastProbes + } + if c.MaxUnicastProbes == 0 { + c.MaxUnicastProbes = defaultMaxUnicastProbes + } + if c.UnreachableTime == 0 { + c.UnreachableTime = defaultUnreachableTime + } +} + +// calcMaxRandomFactor calculates the maximum value of the random factor used +// for computing reachable time. This function is necessary for when the +// default specified in RFC 4861 section 10 is less than the current +// MinRandomFactor. +// +// Assumes minRandomFactor is positive since validation of the minimum value +// should come before the validation of the maximum. +func calcMaxRandomFactor(minRandomFactor float32) float32 { + if minRandomFactor > defaultMaxRandomFactor { + return minRandomFactor * 3 + } + return defaultMaxRandomFactor +} + +// A Rand is a source of random numbers. +type Rand interface { + // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). + Float32() float32 +} + +// NUDState stores states needed for calculating reachable time. +type NUDState struct { + rng Rand + + // mu protects the fields below. + // + // It is necessary for NUDState to handle its own locking since neighbor + // entries may access the NUD state from within the goroutine spawned by + // time.AfterFunc(). This goroutine may run concurrently with the main + // process for controlling the neighbor cache and would otherwise introduce + // race conditions if NUDState was not locked properly. + mu sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 +} + +// NewNUDState returns new NUDState using c as configuration and the specified +// random number generator for use in recomputing ReachableTime. +func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { + s := &NUDState{ + rng: rng, + } + s.config = c + return s +} + +// Config returns the NUD configuration. +func (s *NUDState) Config() NUDConfigurations { + s.mu.RLock() + defer s.mu.RUnlock() + return s.config +} + +// SetConfig replaces the existing NUD configurations with c. +func (s *NUDState) SetConfig(c NUDConfigurations) { + s.mu.Lock() + defer s.mu.Unlock() + s.config = c +} + +// ReachableTime returns the duration to wait for a REACHABLE entry to +// transition into STALE after inactivity. This value is recalculated for new +// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the +// algorithm defined in RFC 4861 section 6.3.2. +func (s *NUDState) ReachableTime() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + + if time.Now().After(s.expiration) || + s.config.BaseReachableTime != s.prevBaseReachableTime || + s.config.MinRandomFactor != s.prevMinRandomFactor || + s.config.MaxRandomFactor != s.prevMaxRandomFactor { + return s.recomputeReachableTimeLocked() + } + return s.reachableTime +} + +// recomputeReachableTimeLocked forces a recalculation of ReachableTime using +// the algorithm defined in RFC 4861 section 6.3.2. +// +// This SHOULD automatically be invoked during certain situations, as per +// RFC 4861 section 6.3.4: +// +// If the received Reachable Time value is non-zero, the host SHOULD set its +// BaseReachableTime variable to the received value. If the new value +// differs from the previous value, the host SHOULD re-compute a new random +// ReachableTime value. ReachableTime is computed as a uniformly +// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR +// times the BaseReachableTime. Using a random component eliminates the +// possibility that Neighbor Unreachability Detection messages will +// synchronize with each other. +// +// In most cases, the advertised Reachable Time value will be the same in +// consecutive Router Advertisements, and a host's BaseReachableTime rarely +// changes. In such cases, an implementation SHOULD ensure that a new +// random value gets re-computed at least once every few hours. +// +// s.mu MUST be locked for writing. +func (s *NUDState) recomputeReachableTimeLocked() time.Duration { + s.prevBaseReachableTime = s.config.BaseReachableTime + s.prevMinRandomFactor = s.config.MinRandomFactor + s.prevMaxRandomFactor = s.config.MaxRandomFactor + + randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + + // Check for overflow, given that minRandomFactor and maxRandomFactor are + // guaranteed to be positive numbers. + if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { + s.reachableTime = time.Duration(math.MaxInt64) + } else if randomFactor == 1 { + // Avoid loss of precision when a large base reachable time is used. + s.reachableTime = s.config.BaseReachableTime + } else { + reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) + s.reachableTime = time.Duration(reachableTime) + } + + s.expiration = time.Now().Add(2 * time.Hour) + return s.reachableTime +} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go new file mode 100644 index 000000000..8cffb9fc6 --- /dev/null +++ b/pkg/tcpip/stack/nud_test.go @@ -0,0 +1,807 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 + defaultMaxAnycastDelayTime = time.Second + defaultMaxReachbilityConfirmations = 3 + defaultUnreachableTime = 5 * time.Second + + defaultFakeRandomNum = 0.5 +) + +// fakeRand is a deterministic random number generator. +type fakeRand struct { + num float32 +} + +var _ stack.Rand = (*fakeRand)(nil) + +func (f *fakeRand) Float32() float32 { + return f.num +} + +// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NUD configurations using an invalid NICID. +func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + UseNeighborCache: true, + }) + + // No NIC with ID 1 yet. + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to retrieve NUD configurations when the +// stack doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to set NUD configurations when the stack +// doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + } +} + +// TestDefaultNUDConfigurationIsValid verifies that calling +// resetInvalidFields() on the result of DefaultNUDConfigurations() does not +// change anything. DefaultNUDConfigurations() should return a valid +// NUDConfigurations. +func TestDefaultNUDConfigurations(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + c, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got, want := c, stack.DefaultNUDConfigurations(); got != want { + t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + } +} + +func TestNUDConfigurationsBaseReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + baseReachableTime: 0, + want: defaultBaseReachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + baseReachableTime: time.Millisecond, + want: time.Millisecond, + }, + { + name: "MoreThanDefaultBaseReachableTime", + baseReachableTime: 2 * defaultBaseReachableTime, + want: 2 * defaultBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = test.baseReachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.BaseReachableTime; got != test.want { + t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMinRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: -1, + want: defaultMinRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: 0, + want: defaultMinRandomFactor, + }, + // Valid cases + { + name: "MoreThanZero", + minRandomFactor: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MinRandomFactor; got != test.want { + t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + maxRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: -1, + want: defaultMaxRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: 0, + want: defaultMaxRandomFactor, + }, + { + name: "LessThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 0.99, + want: defaultMaxRandomFactor, + }, + { + name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", + minRandomFactor: defaultMaxRandomFactor * 2, + maxRandomFactor: defaultMaxRandomFactor, + want: defaultMaxRandomFactor * 6, + }, + // Valid cases + { + name: "EqualToMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor, + want: defaultMinRandomFactor, + }, + { + name: "MoreThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 1.1, + want: defaultMinRandomFactor * 1.1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxRandomFactor; got != test.want { + t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsRetransmitTimer(t *testing.T) { + tests := []struct { + name string + retransmitTimer time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + retransmitTimer: 0, + want: defaultRetransmitTimer, + }, + { + name: "LessThanMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer - time.Nanosecond, + want: defaultRetransmitTimer, + }, + // Valid cases + { + name: "EqualToMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer, + want: minimumBaseReachableTime, + }, + { + name: "LargetThanMinimumRetransmitTimer", + retransmitTimer: 2 * minimumBaseReachableTime, + want: 2 * minimumBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.RetransmitTimer = test.retransmitTimer + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.RetransmitTimer; got != test.want { + t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { + tests := []struct { + name string + delayFirstProbeTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + delayFirstProbeTime: 0, + want: defaultDelayFirstProbeTime, + }, + // Valid cases + { + name: "MoreThanZero", + delayFirstProbeTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.DelayFirstProbeTime = test.delayFirstProbeTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.DelayFirstProbeTime; got != test.want { + t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { + tests := []struct { + name string + maxMulticastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxMulticastProbes: 0, + want: defaultMaxMulticastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxMulticastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxMulticastProbes = test.maxMulticastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxMulticastProbes; got != test.want { + t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { + tests := []struct { + name string + maxUnicastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxUnicastProbes: 0, + want: defaultMaxUnicastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxUnicastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxUnicastProbes = test.maxUnicastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxUnicastProbes; got != test.want { + t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsUnreachableTime(t *testing.T) { + tests := []struct { + name string + unreachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + unreachableTime: 0, + want: defaultUnreachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + unreachableTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.UnreachableTime = test.unreachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NUDConfigs: c, + UseNeighborCache: true, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.UnreachableTime; got != test.want { + t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +// TestNUDStateReachableTime verifies the correctness of the ReachableTime +// computation. +func TestNUDStateReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "AllZeros", + baseReachableTime: 0, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMaxRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMinRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 1, + want: time.Duration(defaultFakeRandomNum * float32(time.Second)), + }, + { + name: "FractionalRandomFactor", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 0.001, + maxRandomFactor: 0.002, + want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), + }, + { + name: "MinAndMaxRandomFactorsEqual", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Second, + }, + { + name: "MinAndMaxRandomFactorsDifferent", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 2, + want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), + }, + { + name: "MaxInt64", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Duration(math.MaxInt64), + }, + { + name: "Overflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1.5, + maxRandomFactor: 1.5, + want: time.Duration(math.MaxInt64), + }, + { + name: "DoubleOverflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 2.5, + maxRandomFactor: 2.5, + want: time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.NUDConfigurations{ + BaseReachableTime: test.baseReachableTime, + MinRandomFactor: test.minRandomFactor, + MaxRandomFactor: test.maxRandomFactor, + } + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} + +// TestNUDStateRecomputeReachableTime exercises the ReachableTime function +// twice to verify recomputation of reachable time when the min random factor, +// max random factor, or base reachable time changes. +func TestNUDStateRecomputeReachableTime(t *testing.T) { + const defaultBase = time.Second + const defaultMin = 2.0 * defaultMaxRandomFactor + const defaultMax = 3.0 * defaultMaxRandomFactor + + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "BaseReachableTime", + baseReachableTime: 2 * defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMax, + want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + { + name: "MinRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMax, + maxRandomFactor: defaultMax, + want: time.Duration(defaultMax * float32(defaultBase)), + }, + { + name: "MaxRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMin, + want: time.Duration(defaultMin * float32(defaultBase)), + }, + { + name: "BothRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), + }, + { + name: "BaseReachableTimeAndBothRandomFactors", + baseReachableTime: 2 * defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = defaultBase + c.MinRandomFactor = defaultMin + c.MaxRandomFactor = defaultMax + + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + old := s.ReachableTime() + + if got, want := s.ReachableTime(), old; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Check for recomputation when changing the min random factor, the max + // random factor, the base reachability time, or any permutation of those + // three options. + c.BaseReachableTime = test.baseReachableTime + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + s.SetConfig(c) + + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Verify that ReachableTime isn't recomputed when none of the + // configuration options change. The random factor is changed so that if + // a recompution were to occur, ReachableTime would change. + rng.num = defaultFakeRandomNum / 2.0 + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1b5da6017..105583c49 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,52 +14,83 @@ package stack import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + // A PacketBuffer contains all the data of a network packet. // // As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. type PacketBuffer struct { - _ noCopy + _ sync.NoCopy // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. + // + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. Data buffer.VectorisedView - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. - Header buffer.Prependable + // headers stores metadata about each header. + headers [numHeaderType]headerInfo - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. - // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable + + // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() + // returns false. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber + + // TransportProtocol is only valid if it is non zero. + // TODO(gvisor.dev/issue/3810): This and the network protocol number should + // be moved into the headerinfo. This should resolve the validity issue. + TransportProtocolNumber tcpip.TransportProtocolNumber // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. @@ -71,45 +102,220 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. - EgressRoute *Route - GSOOptions *GSO - NetworkProtocolNumber tcpip.NetworkProtocolNumber + EgressRoute *Route + GSOOptions *GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. // -// FIXME(b/153685824): Data gets copied but not other header references. +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - Header: pk.Header, - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + newPk := &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, } + return newPk } -// noCopy may be embedded into structs which must not be copied -// after the first use. +// Network returns the network header as a header.Network. // -// See https://golang.org/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} +// Network should only be called when NetworkHeader has been set. +func (pk *PacketBuffer) Network() header.Network { + switch netProto := pk.NetworkProtocolNumber; netProto { + case header.IPv4ProtocolNumber: + return header.IPv4(pk.NetworkHeader().View()) + case header.IPv6ProtocolNumber: + return header.IPv6(pk.NetworkHeader().View()) + default: + panic(fmt.Sprintf("unknown network protocol number %d", netProto)) + } +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} -// Lock is a no-op used by -copylocks checker from `go vet`. -func (*noCopy) Lock() {} -func (*noCopy) Unlock() {} +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) + } + return append(v, h.pk.Data.ToView()...) +} diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5cbc946b6..be9bd8042 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,9 +15,12 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +54,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) @@ -121,6 +127,26 @@ type PacketEndpoint interface { HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } +// UnknownDestinationPacketDisposition enumerates the possible return vaues from +// HandleUnknownDestinationPacket(). +type UnknownDestinationPacketDisposition int + +const ( + // UnknownDestinationPacketMalformed denotes that the packet was malformed + // and no further processing should be attempted other than updating + // statistics. + UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota + + // UnknownDestinationPacketUnhandled tells the caller that the packet was + // well formed but that the issue was not handled and the stack should take + // the default action. + UnknownDestinationPacketUnhandled + + // UnknownDestinationPacketHandled tells the caller that it should do + // no further processing. + UnknownDestinationPacketHandled +) + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { @@ -128,10 +154,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -143,24 +169,22 @@ type TransportProtocol interface { ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this - // protocol but that don't match any existing endpoint. For example, - // it is targeted at a port that have no listeners. - // - // The return value indicates whether the packet was well-formed (for - // stats purposes only). + // protocol that don't match any existing endpoint. For example, + // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // the issue. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -175,6 +199,25 @@ type TransportProtocol interface { Parse(pkt *PacketBuffer) (ok bool) } +// TransportPacketDisposition is the result from attempting to deliver a packet +// to the transport layer. +type TransportPacketDisposition int + +const ( + // TransportPacketHandled indicates that a transport packet was handled by the + // transport layer and callers need not take any further action. + TransportPacketHandled TransportPacketDisposition = iota + + // TransportPacketProtocolUnreachable indicates that the transport + // protocol requested in the packet is not supported. + TransportPacketProtocolUnreachable + + // TransportPacketDestinationPortUnreachable indicates that there weren't any + // listeners interested in the packet and the transport protocol has no means + // to notify the sender. + TransportPacketDestinationPortUnreachable +) + // TransportDispatcher contains the methods used by the network stack to deliver // packets to the appropriate transport endpoint after it has been handled by // the network layer. @@ -185,7 +228,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -222,9 +265,253 @@ type NetworkHeaderParams struct { TOS uint8 } +// GroupAddressableEndpoint is an endpoint that supports group addressing. +// +// An endpoint is considered to support group addressing when one or more +// endpoints may associate themselves with the same identifier (group address). +type GroupAddressableEndpoint interface { + // JoinGroup joins the spcified group. + // + // Returns true if the group was newly joined. + JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + + // LeaveGroup attempts to leave the specified group. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. + LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + + // IsInGroup returns true if the endpoint is a member of the specified group. + IsInGroup(group tcpip.Address) bool +} + +// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary +// behavior. +type PrimaryEndpointBehavior int + +const ( + // CanBePrimaryEndpoint indicates the endpoint can be used as a primary + // endpoint for new connections with no local address. This is the + // default when calling NIC.AddAddress. + CanBePrimaryEndpoint PrimaryEndpointBehavior = iota + + // FirstPrimaryEndpoint indicates the endpoint should be the first + // primary endpoint considered. If there are multiple endpoints with + // this behavior, they are ordered by recency. + FirstPrimaryEndpoint + + // NeverPrimaryEndpoint indicates the endpoint should never be a + // primary endpoint. + NeverPrimaryEndpoint +) + +// AddressConfigType is the method used to add an address. +type AddressConfigType int + +const ( + // AddressConfigStatic is a statically configured address endpoint that was + // added by some user-specified action (adding an explicit address, joining a + // multicast group). + AddressConfigStatic AddressConfigType = iota + + // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862 + // section 5.5.3. + AddressConfigSlaac + + // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as + // per RFC 4941. Temporary SLAAC addresses are short-lived and are not + // to be valid (or preferred) forever; hence the term temporary. + AddressConfigSlaacTemp +) + +// AssignableAddressEndpoint is a reference counted address endpoint that may be +// assigned to a NetworkEndpoint. +type AssignableAddressEndpoint interface { + // AddressWithPrefix returns the endpoint's address. + AddressWithPrefix() tcpip.AddressWithPrefix + + // IsAssigned returns whether or not the endpoint is considered bound + // to its NetworkEndpoint. + IsAssigned(allowExpired bool) bool + + // IncRef increments this endpoint's reference count. + // + // Returns true if it was successfully incremented. If it returns false, then + // the endpoint is considered expired and should no longer be used. + IncRef() bool + + // DecRef decrements this endpoint's reference count. + DecRef() +} + +// AddressEndpoint is an endpoint representing an address assigned to an +// AddressableEndpoint. +type AddressEndpoint interface { + AssignableAddressEndpoint + + // GetKind returns the address kind for this endpoint. + GetKind() AddressKind + + // SetKind sets the address kind for this endpoint. + SetKind(AddressKind) + + // ConfigType returns the method used to add the address. + ConfigType() AddressConfigType + + // Deprecated returns whether or not this endpoint is deprecated. + Deprecated() bool + + // SetDeprecated sets this endpoint's deprecated status. + SetDeprecated(bool) +} + +// AddressKind is the kind of of an address. +// +// See the values of AddressKind for more details. +type AddressKind int + +const ( + // PermanentTentative is a permanent address endpoint that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses are of this kind until NDP's + // Duplicate Address Detection (DAD) resolves. If DAD fails, the address + // is removed. + PermanentTentative AddressKind = iota + + // Permanent is a permanent endpoint (vs. a temporary one) assigned to the + // NIC. Its reference count is biased by 1 to avoid removal when no route + // holds a reference to it. It is removed by explicitly removing the address + // from the NIC. + Permanent + + // PermanentExpired is a permanent endpoint that had its address removed from + // the NIC, and it is waiting to be removed once no references to it are held. + // + // If the address is re-added before the endpoint is removed, its type + // changes back to Permanent. + PermanentExpired + + // Temporary is an endpoint, created on a one-off basis to temporarily + // consider the NIC bound an an address that it is not explictiy bound to + // (such as a permanent address). Its reference count must not be biased by 1 + // so that the address is removed immediately when references to it are no + // longer held. + // + // A temporary endpoint may be promoted to permanent if the address is added + // permanently. + Temporary +) + +// IsPermanent returns true if the AddressKind represents a permanent address. +func (k AddressKind) IsPermanent() bool { + switch k { + case Permanent, PermanentTentative: + return true + case Temporary, PermanentExpired: + return false + default: + panic(fmt.Sprintf("unrecognized address kind = %d", k)) + } +} + +// AddressableEndpoint is an endpoint that supports addressing. +// +// An endpoint is considered to support addressing when the endpoint may +// associate itself with an identifier (address). +type AddressableEndpoint interface { + // AddAndAcquirePermanentAddress adds the passed permanent address. + // + // Returns tcpip.ErrDuplicateAddress if the address exists. + // + // Acquires and returns the AddressEndpoint for the added address. + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + + // RemovePermanentAddress removes the passed address if it is a permanent + // address. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // permanent address. + RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + + // MainAddress returns the endpoint's primary permanent address. + MainAddress() tcpip.AddressWithPrefix + + // AcquireAssignedAddress returns an address endpoint for the passed address + // that is considered bound to the endpoint, optionally creating a temporary + // endpoint if requested and no existing address exists. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if the specified address is not local to this endpoint. + AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint + + // AcquireOutgoingPrimaryAddress returns a primary address that may be used as + // a source address when sending packets to the passed remote address. + // + // If allowExpired is true, expired addresses may be returned. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if a primary address is not available. + AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint + + // PrimaryAddresses returns the primary addresses. + PrimaryAddresses() []tcpip.AddressWithPrefix + + // PermanentAddresses returns all the permanent addresses. + PermanentAddresses() []tcpip.AddressWithPrefix +} + +// NDPEndpoint is a network endpoint that supports NDP. +type NDPEndpoint interface { + NetworkEndpoint + + // InvalidateDefaultRouter invalidates a default router discovered through + // NDP. + InvalidateDefaultRouter(tcpip.Address) +} + +// NetworkInterface is a network interface. +type NetworkInterface interface { + // ID returns the interface's ID. + ID() tcpip.NICID + + // IsLoopback returns true if the interface is a loopback interface. + IsLoopback() bool + + // Name returns the name of the interface. + // + // May return an empty string if the interface is not configured with a name. + Name() string + + // Enabled returns true if the interface is enabled. + Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { + AddressableEndpoint + + // Enable enables the endpoint. + // + // Must only be called when the stack is in a state that allows the endpoint + // to send and receive packets. + // + // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() *tcpip.Error + + // Enabled returns true if the endpoint is enabled. + Enabled() bool + + // Disable disables the endpoint. + Disable() + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) // for this endpoint. DefaultTTL() uint8 @@ -234,10 +521,6 @@ type NetworkEndpoint interface { // minus the network endpoint max header length. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // underlying link-layer endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the network (and lower // level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -245,8 +528,8 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have already - // been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and @@ -258,15 +541,6 @@ type NetworkEndpoint interface { // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error - // ID returns the network protocol endpoint ID. - ID() *NetworkEndpointID - - // PrefixLen returns the network endpoint's subnet prefix length in bits. - PrefixLen() int - - // NICID returns the id of the NIC this endpoint belongs to. - NICID() tcpip.NICID - // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // @@ -281,6 +555,17 @@ type NetworkEndpoint interface { NetworkProtocolNumber() tcpip.NetworkProtocolNumber } +// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. +type ForwardingNetworkProtocol interface { + NetworkProtocol + + // Forwarding returns the forwarding configuration. + Forwarding() bool + + // SetForwarding sets the forwarding configuration. + SetForwarding(bool) +} + // NetworkProtocol is the interface that needs to be implemented by network // protocols (e.g., ipv4, ipv6) that want to be part of the networking stack. type NetworkProtocol interface { @@ -300,17 +585,17 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -329,8 +614,7 @@ type NetworkProtocol interface { } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -341,6 +625,16 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -420,8 +714,8 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // - // Attach will be called with a nil dispatcher if the receiver's associated - // NIC is being removed. + // Attach is called with a nil dispatcher when the endpoint's NIC is being + // removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the @@ -436,6 +730,15 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -456,12 +759,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. - // The request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts + // the request on the local network if remoteLinkAddr is the zero value. The + // request is sent on linkEP with localAddr as the source. // // A valid response will cause the discovery protocol's network // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -471,7 +775,7 @@ type LinkAddressResolver interface { ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) // LinkAddressProtocol returns the network protocol of the - // addresses this this resolver can resolve. + // addresses this resolver can resolve. LinkAddressProtocol() tcpip.NetworkProtocolNumber } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index d65f8049e..cc39c9a6a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -42,17 +42,27 @@ type Route struct { // NetProto is the network-layer protocol. NetProto tcpip.NetworkProtocolNumber - // ref a reference to the network endpoint through which the route - // starts. - ref *referencedNetworkEndpoint - // Loop controls where WritePacket should send packets. Loop PacketLooping + + // nic is the NIC the route goes through. + nic *NIC + + // addressEndpoint is the local address this route is associated with. + addressEndpoint AssignableAddressEndpoint + + // linkCache is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkCache LinkAddressCache + + // linkRes is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkRes LinkAddressResolver } // makeRoute initializes a new route. It takes ownership of the provided -// reference to a network endpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route { +// AssignableAddressEndpoint. +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { loop := PacketOut if handleLocal && localAddr != "" && remoteAddr == localAddr { loop = PacketLoop @@ -62,29 +72,40 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - return Route{ + linkEP := nic.LinkEndpoint() + r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: localLinkAddr, + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, - ref: ref, + addressEndpoint: addressEndpoint, + nic: nic, Loop: loop, } + + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + r.linkRes = linkRes + r.linkCache = nic.stack + } + } + + return r } // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.ref.ep.NICID() + return r.nic.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.ref.ep.MaxHeaderLength() + return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.ref.nic.stack.Stats() + return r.nic.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -95,17 +116,23 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.ref.ep.Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.ref.ep.(GSOEndpoint); ok { + if gso, ok := r.nic.getNetworkEndpoint(r.NetProto).(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -131,7 +158,17 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { } nextAddr = r.RemoteAddress } - linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + + if neigh := r.nic.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + if err != nil { + return ch, err + } + r.RemoteLinkAddress = entry.LinkAddr + return nil, nil + } + + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { return ch, err } @@ -145,7 +182,13 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { if nextAddr == "" { nextAddr = r.RemoteAddress } - r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) + + if neigh := r.nic.neigh; neigh != nil { + neigh.removeWaker(nextAddr, waker) + return + } + + r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) } // IsResolutionRequired returns true if Resolve() must be called to resolve @@ -153,102 +196,88 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // // The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { - return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" + if r.nic.neigh != nil { + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + } + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + numBytes := pkt.Size() - err := r.ref.ep.WritePacket(r, gso, params, pkt) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil { + return err } - return err + + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return 0, tcpip.ErrInvalidEndpointState } - // WritePackets takes ownership of pkt, calculate length first. - numPkts := pkts.Len() - - n, err := r.ref.ep.WritePackets(r, gso, pkts, params) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) - } - r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - + n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Header.UsedLength() - writtenBytes += pb.Data.Size() + writtenBytes += pb.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Data.Size() - if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil { return err } - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.ref.ep.DefaultTTL() + return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.ref.ep.MTU() -} - -// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying -// network endpoint. -func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.ref.ep.NetworkProtocolNumber() + return r.nic.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.ref != nil { - r.ref.decRef() - r.ref = nil + if r.addressEndpoint != nil { + r.addressEndpoint.DecRef() + r.addressEndpoint = nil } } -// Clone Clone a route such that the original one can be released and the new -// one will remain valid. +// Clone clones the route. func (r *Route) Clone() Route { - if r.ref != nil { - r.ref.incRef() + if r.addressEndpoint != nil { + _ = r.addressEndpoint.IncRef() } return *r } @@ -272,7 +301,30 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.ref.stack() + return r.nic.stack +} + +func (r *Route) isV4Broadcast(addr tcpip.Address) bool { + if addr == header.IPv4Broadcast { + return true + } + + subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + return subnet.IsBroadcast(addr) +} + +// IsOutboundBroadcast returns true if the route is for an outbound broadcast +// packet. +func (r *Route) IsOutboundBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.isV4Broadcast(r.RemoteAddress) +} + +// IsInboundBroadcast returns true if the route is for an inbound broadcast +// packet. +func (r *Route) IsInboundBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.isV4Broadcast(r.LocalAddress) } // ReverseRoute returns new route with given source and destination address. @@ -283,7 +335,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { LocalLinkAddress: r.RemoteLinkAddress, RemoteAddress: src, RemoteLinkAddress: r.LocalLinkAddress, - ref: r.ref, Loop: r.Loop, + addressEndpoint: r.addressEndpoint, + nic: r.nic, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cdcfb8321..0bf20c0e1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -73,6 +73,16 @@ type TCPCubicState struct { WEst float64 } +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +type TCPRACKState struct { + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool +} + // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. type TCPEndpointID struct { // LocalPort is the local port associated with the endpoint. @@ -134,10 +144,7 @@ type TCPReceiverState struct { // PendingBufUsed is the number of bytes pending in the receive // queue. - PendingBufUsed seqnum.Size - - // PendingBufSize is the size of the socket receive buffer. - PendingBufSize seqnum.Size + PendingBufUsed int } // TCPSenderState holds a copy of the internal state of the sender for @@ -212,6 +219,9 @@ type TCPSenderState struct { // Cubic holds the state related to CUBIC congestion control. Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. @@ -235,7 +245,7 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to userspace since + // CopiedBytes is the number of bytes copied to user space since // this measure began. CopiedBytes int @@ -353,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } -// NICNameFromID is a function that returns a stable name for the specified NIC, -// even if different NIC IDs are used to refer to the same NIC in different -// program runs. It is used when generating opaque interface identifiers (IIDs). -// If the NIC was created with a name, it will be passed to NICNameFromID. -// -// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. -type NICNameFromID func(tcpip.NICID, string) string - -// OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. -type OpaqueInterfaceIdentifierOptions struct { - // NICNameFromID is a function that returns a stable name for a specified NIC, - // even if the NIC ID changes over time. - // - // Must be specified to generate the opaque IID. - NICNameFromID NICNameFromID - - // SecretKey is a pseudo-random number used as the secret key when generating - // opaque IIDs as defined by RFC 7217. The key SHOULD be at least - // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness - // requirements for security as outlined by RFC 4086. SecretKey MUST NOT - // change between program runs, unless explicitly changed. - // - // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey - // MUST NOT be modified after Stack is created. - // - // May be nil, but a nil value is highly discouraged to maintain - // some level of randomness between nodes. - SecretKey []byte -} - // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -402,10 +380,12 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool - cleanupEndpoints map[TransportEndpoint]struct{} + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + + // cleanupEndpointsMu protects cleanupEndpoints. + cleanupEndpointsMu sync.Mutex + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -416,7 +396,7 @@ type Stack struct { // If not nil, then any new endpoints will have this probe function // invoked everytime they receive a TCP segment. - tcpProbeFunc TCPProbeFunc + tcpProbeFunc atomic.Value // TCPProbeFunc // clock is used to generate user-visible times. clock tcpip.Clock @@ -425,6 +405,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -441,29 +422,20 @@ type Stack struct { // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 - // ndpConfigs is the default NDP configurations used by interfaces. - ndpConfigs NDPConfigurations + // nudConfigs is the default NUD configurations used by interfaces. + nudConfigs NUDConfigurations - // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool + // useNeighborCache indicates whether ARP and NDP packets should be handled + // by the NIC's neighborCache instead of linkAddrCache. + useNeighborCache bool - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher + // nudDisp is the NUD event dispatcher that is used to send the netstack + // integrator NUD related events. + nudDisp NUDDispatcher // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -486,13 +458,25 @@ type UniqueID interface { UniqueID() uint64 } +// NetworkProtocolFactory instantiates a network protocol. +// +// NetworkProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type NetworkProtocolFactory func(*Stack) NetworkProtocol + +// TransportProtocolFactory instantiates a transport protocol. +// +// TransportProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type TransportProtocolFactory func(*Stack) TransportProtocol + // Options contains optional Stack configuration. type Options struct { // NetworkProtocols lists the network protocols to enable. - NetworkProtocols []NetworkProtocol + NetworkProtocols []NetworkProtocolFactory // TransportProtocols lists the transport protocols to enable. - TransportProtocols []TransportProtocol + TransportProtocols []TransportProtocolFactory // Clock is an optional clock source used for timestampping packets. // @@ -510,60 +494,30 @@ type Options struct { // UniqueID is an optional generator of unique identifiers. UniqueID UniqueID - // NDPConfigs is the default NDP configurations used by interfaces. - // - // By default, NDPConfigs will have a zero value for its - // DupAddrDetectTransmits field, implying that DAD will not be performed - // before assigning an address to a NIC. - NDPConfigs NDPConfigurations - - // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. - // - // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address Detection - // is enabled, an address will only be assigned if it successfully resolves. - // If it fails, no further attempt will be made to auto-generate an IPv6 - // link-local address. - // - // The generated link-local address will follow RFC 4291 Appendix A - // guidelines. - AutoGenIPv6LinkLocal bool + // NUDConfigs is the default NUD configurations used by interfaces. + NUDConfigs NUDConfigurations + + // UseNeighborCache indicates whether ARP and NDP packets should be handled + // by the Neighbor Unreachability Detection (NUD) state machine. This flag + // also enables the APIs for inspecting and modifying the neighbor table via + // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor, + // and ClearNeighbors. + UseNeighborCache bool - // NDPDisp is the NDP event dispatcher that an integrator can provide to - // receive NDP related events. - NDPDisp NDPDispatcher + // NUDDisp is the NUD event dispatcher that an integrator can provide to + // receive NUD related events. + NUDDisp NUDDispatcher // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface - // identifiers (IIDs) as outlined by RFC 7217. - OpaqueIIDOpts OpaqueInterfaceIdentifierOptions - // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by rand.Read(). // // RandSource must be thread-safe. RandSource mathrand.Source - - // TempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable - // and random from the perspective of other nodes on the network. It is - // recommended that the seed be a random byte buffer of at least - // header.IIDSize bytes to make sure that temporary SLAAC addresses are - // sufficiently random. It should follow minimum randomness requirements for - // security as outlined by RFC 4086. - // - // Note: using a nil value, the same seed across netstack program runs, or a - // seed that is too small would reduce randomness and increase predictability, - // defeating the purpose of temporary SLAAC addresses. - TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -666,31 +620,28 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } - // Make sure opts.NDPConfigs contains valid values only. - opts.NDPConfigs.validate() + opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: DefaultTables(), - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - ndpConfigs: opts.NDPConfigs, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, - uniqueIDGenerator: opts.UniqueID, - ndpDisp: opts.NDPDisp, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - forwarder: newForwardQueue(), - randomGenerator: mathrand.New(randSrc), + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: DefaultTables(), + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + useNeighborCache: opts.UseNeighborCache, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -704,7 +655,8 @@ func New(opts Options) *Stack { } // Add specified network protocols. - for _, netProto := range opts.NetworkProtocols { + for _, netProtoFactory := range opts.NetworkProtocols { + netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -712,7 +664,8 @@ func New(opts Options) *Stack { } // Add specified transport protocols. - for _, transProto := range opts.TransportProtocols { + for _, transProtoFactory := range opts.TransportProtocols { + transProto := transProtoFactory(s) s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } @@ -727,6 +680,11 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + // UniqueID returns a unique identifier. func (s *Stack) UniqueID() uint64 { return s.uniqueIDGenerator.UniqueID() @@ -736,7 +694,7 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { return tcpip.ErrUnknownProtocol @@ -753,7 +711,7 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { return tcpip.ErrUnknownProtocol @@ -765,7 +723,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -780,7 +738,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -800,9 +758,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. @@ -813,46 +772,37 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs for the +// passed protocol. +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return tcpip.ErrUnknownProtocol + } - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { - return + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return tcpip.ErrNotSupported } - s.forwarding = enable + forwardingProtocol.SetForwarding(enable) + return nil +} - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { - return +// Forwarding returns true if packet forwarding between NICs is enabled for the +// passed protocol. +func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return false } - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() - } + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return false } -} -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding + return forwardingProtocol.Forwarding() } // SetRouteTable assigns the route table to be used by this stack. It @@ -887,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewEndpoint(s, network, waiterQueue) + return t.proto.NewEndpoint(network, waiterQueue) } // NewRawEndpoint creates a new raw transport layer endpoint of the given @@ -907,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewRawEndpoint(s, network, waiterQueue) + return t.proto.NewRawEndpoint(network, waiterQueue) } // NewPacketEndpoint creates a new packet endpoint listening for the given @@ -1014,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - return nic.disable() + nic.disable() + return nil } // CheckNIC checks if a NIC is usable. @@ -1027,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } - return nic.enabled() + return nic.Enabled() } // RemoveNIC removes NIC and all related routes from the network stack. @@ -1064,19 +1015,6 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { return nic.remove() } -// NICAddressRanges returns a map of NICIDs to their associated subnets. -func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { - s.mu.RLock() - defer s.mu.RUnlock() - - nics := map[tcpip.NICID][]tcpip.Subnet{} - - for id, nic := range s.nics { - nics[id] = append(nics[id], nic.AddressRanges()...) - } - return nics -} - // NICInfo captures the name and addresses assigned to a NIC. type NICInfo struct { Name string @@ -1094,6 +1032,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1113,18 +1056,19 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.enabled(), + Running: nic.Enabled(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.isLoopback(), + Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.PrimaryAddresses(), + ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -1178,41 +1122,12 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocolAddress, peb) -} - -// AddAddressRange adds a range of addresses to the specified NIC. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.AddAddressRange(protocol, subnet) - return nil - } - - return tcpip.ErrUnknownNICID -} - -// RemoveAddressRange removes the range of addresses from the specified NIC. -func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.RemoveAddressRange(subnet) - return nil - } - - return tcpip.ErrUnknownNICID + return nic.addAddress(protocolAddress, peb) } // RemoveAddress removes an existing network-layer address from the specified @@ -1222,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + return nic.removeAddress(addr) } return tcpip.ErrUnknownNICID @@ -1236,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) for id, nic := range s.nics { - nics[id] = nic.AllAddresses() + nics[id] = nic.allPermanentAddresses() } return nics } @@ -1258,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { if len(localAddr) == 0 { return nic.primaryEndpoint(netProto, remoteAddr) } @@ -1271,13 +1186,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isBroadcast := remoteAddr == header.IPv4Broadcast + isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil + if nic, ok := s.nics[id]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil } } } else { @@ -1285,18 +1200,23 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { + if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. - remoteAddr = ref.ep.ID().LocalAddress + remoteAddr = addressEndpoint.AddressWithPrefix().Address } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - if needRoute { - r.NextHop = route.Gateway + r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) + if len(route.Gateway) > 0 { + if needRoute { + r.NextHop = route.Gateway + } + } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } + return r, nil } } @@ -1326,26 +1246,25 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto // If a NIC is specified, we try to find the address there only. if nicID != 0 { - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return 0 } - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref == nil { + addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) + if addressEndpoint == nil { return 0 } - ref.decRef() + addressEndpoint.DecRef() return nic.id } // Go through all the NICs. for _, nic := range s.nics { - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref != nil { - ref.decRef() + if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() return nic.id } } @@ -1358,8 +1277,8 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return tcpip.ErrUnknownNICID } @@ -1374,8 +1293,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return tcpip.ErrUnknownNICID } @@ -1407,8 +1326,33 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } -// RemoveWaker implements LinkAddressCache.RemoveWaker. +// Neighbors returns all IP to MAC address associations. +func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return nil, tcpip.ErrUnknownNICID + } + + return nic.neighbors() +} + +// RemoveWaker removes a waker that has been added when link resolution for +// addr was requested. func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + if s.useNeighborCache { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if ok { + nic.removeWaker(addr, waker) + } + return + } + s.mu.RLock() defer s.mu.RUnlock() @@ -1418,6 +1362,47 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. } } +// AddStaticNeighbor statically associates an IP address to a MAC address. +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.addStaticNeighbor(addr, linkAddr) +} + +// RemoveNeighbor removes an IP to MAC address association previously created +// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there +// is no association with the provided address. +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.removeNeighbor(addr) +} + +// ClearNeighbors removes all IP to MAC address associations. +func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.clearNeighbors() +} + // RegisterTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but @@ -1441,10 +1426,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { - s.mu.Lock() - defer s.mu.Unlock() - + s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} + s.cleanupEndpointsMu.Unlock() s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } @@ -1452,9 +1436,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp // CompleteTransportEndpointCleanup removes the endpoint from the cleanup // stage. func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() delete(s.cleanupEndpoints, ep) - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // FindTransportEndpoint finds an endpoint that most closely matches the provided @@ -1497,23 +1481,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint { // CleanupEndpoints returns endpoints currently in the cleanup state. func (s *Stack) CleanupEndpoints() []TransportEndpoint { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) for e := range s.cleanupEndpoints { es = append(es, e) } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() return es } // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() for _, e := range es { s.cleanupEndpoints[e] = struct{}{} } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // Close closes all currently registered transport endpoints. @@ -1708,18 +1692,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra // guarantee provided on which probe will be invoked. Ideally this should only // be called once per stack. func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { - s.mu.Lock() - s.tcpProbeFunc = probe - s.mu.Unlock() + s.tcpProbeFunc.Store(probe) } // GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil // otherwise. func (s *Stack) GetTCPProbe() TCPProbeFunc { - s.mu.Lock() - p := s.tcpProbeFunc - s.mu.Unlock() - return p + p := s.tcpProbeFunc.Load() + if p == nil { + return nil + } + return p.(TCPProbeFunc) } // RemoveTCPProbe removes an installed TCP probe. @@ -1728,9 +1711,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc { // have a probe attached. Endpoints already created will continue to invoke // TCP probe. func (s *Stack) RemoveTCPProbe() { - s.mu.Lock() - s.tcpProbeFunc = nil - s.mu.Unlock() + // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics. + s.tcpProbeFunc.Store(TCPProbeFunc(nil)) } // JoinGroup joins the given multicast group on the given NIC. @@ -1751,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { - return nic.leaveGroup(multicastAddr) + return nic.leaveGroup(protocol, multicastAddr) } return tcpip.ErrUnknownNICID } @@ -1803,70 +1785,47 @@ func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } -// IsAddrTentative returns true if addr is tentative on the NIC with ID id. -// -// Note that if addr is not associated with a NIC with id ID, then this -// function will return false. It will only return true if the address is -// associated with the NIC AND it is tentative. -func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() +// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol +// number installed on the specified NIC. +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { + s.mu.Lock() + defer s.mu.Unlock() - nic, ok := s.nics[id] + nic, ok := s.nics[nicID] if !ok { - return false, tcpip.ErrUnknownNICID + return nil, tcpip.ErrUnknownNICID } - return nic.isAddrTentative(addr), nil + return nic.getNetworkEndpoint(proto), nil } -// DupTentativeAddrDetected attempts to inform the NIC with ID id that a -// tentative addr on it is a duplicate on a link. -func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - +// NUDConfigurations gets the per-interface NUD configurations. +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { + s.mu.RLock() nic, ok := s.nics[id] + s.mu.RUnlock() + if !ok { - return tcpip.ErrUnknownNICID + return NUDConfigurations{}, tcpip.ErrUnknownNICID } - return nic.dupTentativeAddrDetected(addr) + return nic.nudConfigs() } -// SetNDPConfigurations sets the per-interface NDP configurations on the NIC -// with ID id to c. +// SetNUDConfigurations sets the per-interface NUD configurations. // -// Note, if c contains invalid NDP configuration values, it will be fixed to +// Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { + s.mu.RLock() nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - nic.setNDPConfigs(c) - - return nil -} - -// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement -// message that it needs to handle. -func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RUnlock() - nic, ok := s.nics[id] if !ok { return tcpip.ErrUnknownNICID } - nic.handleNDPRA(ip, ra) - - return nil + return nic.setNUDConfigs(c) } // Seed returns a 32 bit value that can be used as a seed value for port @@ -1906,28 +1865,24 @@ func generateRandInt64() int64 { // FindNetworkEndpoint returns the network endpoint for the given address. func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() for _, nic := range s.nics { - id := NetworkEndpointID{address} - - if ref, ok := nic.mu.endpoints[id]; ok { - nic.mu.RLock() - defer nic.mu.RUnlock() - - // An endpoint with this id exists, check if it can be - // used and return it. - return ref.ep, nil + addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + continue } + addressEndpoint.DecRef() + return nic.getNetworkEndpoint(netProto), nil } return nil, tcpip.ErrBadAddress } -// FindNICNameFromID returns the name of the nic for the given NICID. +// FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { @@ -1936,3 +1891,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { return nic.Name() } + +// NewJob returns a new tcpip.Job using the stack's clock. +func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7657a4101..aa20f750b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,18 +21,21 @@ import ( "bytes" "fmt" "math" + "net" "sort" - "strings" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -66,40 +69,53 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { + stack.AddressableEndpointState + + mu struct { + sync.RWMutex + + enabled bool + } + nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint } -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = true + return nil } -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (f *fakeNetworkEndpoint) Enabled() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.enabled } -func (f *fakeNetworkEndpoint) PrefixLen() int { - return f.prefixLen +func (f *fakeNetworkEndpoint) Disable() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = false } -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 +func (f *fakeNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id +func (*fakeNetworkEndpoint) DefaultTTL() uint8 { + return 123 } func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ + f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -115,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -126,10 +142,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto return 0 } -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -140,10 +152,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) - pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] - pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] - pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = r.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { f.HandlePacket(r, pkt) @@ -164,16 +176,8 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack return tcpip.ErrNotSupported } -func (*fakeNetworkEndpoint) Close() {} - -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool +func (f *fakeNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() } // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the @@ -182,7 +186,12 @@ type fakeNetOptions struct { type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int - opts fakeNetOptions + defaultTTL uint8 + + mu struct { + sync.RWMutex + forwarding bool + } } func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -205,57 +214,67 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - return &fakeNetworkEndpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &fakeNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, - }, nil + ep: nic.LinkEndpoint(), + } + e.AddressableEndpointState.Init(e) + return e } -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) + case *tcpip.DefaultTTLOption: + f.defaultTTL = uint8(*v) return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: return tcpip.ErrUnknownProtocolOption } } -// Close implements TransportProtocol.Close. +// Close implements NetworkProtocol.Close. func (*fakeNetworkProtocol) Close() {} -// Wait implements TransportProtocol.Wait. +// Wait implements NetworkProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} -// Parse implements TransportProtocol.Parse. +// Parse implements NetworkProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(fakeNetHeaderLen) return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -func fakeNetFactory() stack.NetworkProtocol { +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + +func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -276,12 +295,23 @@ func (l *linkEPWithMockedAttach) isAttached() bool { return l.attached } +// Checks to see if list contains an address. +func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { + for _, i := range list { + if i == item { + return true + } + } + + return false +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -301,9 +331,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -313,9 +343,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -325,9 +355,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -336,9 +366,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -348,9 +378,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -369,11 +399,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -428,9 +457,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -442,7 +471,7 @@ func TestNetworkSend(t *testing.T) { // existing nic. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) @@ -469,7 +498,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -569,7 +598,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -588,7 +617,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { @@ -600,7 +629,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := loopback.New() @@ -647,7 +676,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { @@ -659,7 +688,7 @@ func TestRemoveNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -720,7 +749,7 @@ func TestRouteWithDownNIC(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(1, defaultMTU, "") @@ -867,9 +896,9 @@ func TestRouteWithDownNIC(t *testing.T) { // Writes with Routes that use NIC1 after being brought up should // succeed. // - // TODO(b/147015577): Should we instead completely invalidate all - // Routes that were bound to a NIC that was brought down at some - // point? + // TODO(gvisor.dev/issue/1491): Should we instead completely + // invalidate all Routes that were bound to a NIC that was brought + // down at some point? if err := upFn(s, nicID1); err != nil { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } @@ -886,7 +915,7 @@ func TestRoutes(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -966,7 +995,7 @@ func TestAddressRemoval(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1013,7 +1042,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1104,7 +1133,7 @@ func TestEndpointExpiration(t *testing.T) { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1262,7 +1291,7 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1314,7 +1343,7 @@ func TestSpoofingWithAddress(t *testing.T) { dstAddr := tcpip.Address("\x03") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1380,7 +1409,7 @@ func TestSpoofingNoAddress(t *testing.T) { dstAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1443,7 +1472,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error { func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1486,7 +1515,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // Create a new stack with two NICs. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1587,7 +1616,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1642,239 +1671,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Add a range of addresses, then check that a packet is delivered. -func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { - t.Helper() - - // Loop over all addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - - addrBytes := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - addr := tcpip.Address(addrBytes) - wantNicID := nicID - // The subnet and broadcast addresses are skipped. - if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { - wantNicID = 0 - } - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) - } - addrBytes[0]++ - } - - // Trying the next address should always fail since it is outside the range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) - } -} - -// Set a range of addresses, then remove it again, and check at each step that -// CheckLocalAddress returns the correct NIC for each address or zero if not -// existent. -func TestCheckLocalAddressForSubnet(t *testing.T) { - const nicID tcpip.NICID = 1 +func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{}, }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) + opt := tcpip.DefaultTTLOption(5) + if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { + t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) } - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) - - if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) + var optGot tcpip.DefaultTTLOption + if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { + t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) } - testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) - - if err := s.RemoveAddressRange(nicID, subnet); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) -} - -// Set a range of addresses, then send a packet to a destination outside the -// range and then check it doesn't get delivered. -func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func TestNetworkOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, - }) - - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } - } -} - -func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { - ranges, ok := s.NICAddressRanges()[id] - if !ok { - return false - } - for _, r := range ranges { - if r == addrRange { - return true - } - } - return false -} - -func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - addrRange, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.RemoveAddressRange(1, addrRange); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) + if opt != optGot { + t.Errorf("got optGot = %d, want = %d", optGot, opt) } } @@ -1886,7 +1700,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1953,7 +1767,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -2038,7 +1852,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2065,7 +1879,7 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2099,7 +1913,7 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2130,7 +1944,7 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2251,7 +2065,7 @@ func TestCreateNICWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2271,9 +2085,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -2318,9 +2132,9 @@ func TestNICForwarding(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { @@ -2353,9 +2167,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) pkt, ok := ep2.Read() if !ok { @@ -2363,8 +2177,8 @@ func TestNICForwarding(t *testing.T) { } // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) } // Test that forwarding increments Tx stats correctly. @@ -2442,7 +2256,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName string autoGen bool linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions + iidOpts ipv6.OpaqueInterfaceIdentifierOptions shouldGen bool expectedAddr tcpip.Address }{ @@ -2458,7 +2272,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: false, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2503,7 +2317,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: true, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2515,7 +2329,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { { name: "OIID Empty MAC and empty nicName", autoGen: true, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:1], }, @@ -2527,7 +2341,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test", autoGen: true, linkAddr: "\x01\x02\x03", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:2], }, @@ -2539,7 +2353,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test2", autoGen: true, linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:3], }, @@ -2551,7 +2365,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test3", autoGen: true, linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, }, shouldGen: true, @@ -2565,10 +2379,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, + })}, } e := channel.New(0, 1280, test.linkAddr) @@ -2640,15 +2455,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { tests := []struct { name string - opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions }{ { name: "IID From MAC", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, }, { name: "Opaque IID", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName }, @@ -2659,9 +2474,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + })}, } e := loopback.New() @@ -2690,12 +2506,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - ndpConfigs := stack.DefaultNDPConfigurations() + ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + })}, } e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2751,7 +2568,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { for _, ps := range pebs { t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -3042,14 +2859,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3088,59 +2906,58 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { const nicID = 1 + broadcastAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) } - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } // Enabling the NIC should add the IPv4 broadcast address. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 1 { - t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) - } - want := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - if allNICAddrs[0] != want { - t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if !containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) + } } // Disabling the NIC should remove the IPv4 broadcast address. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } } @@ -3151,7 +2968,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -3188,50 +3005,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { } } -func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { +func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + name: "IPv6 All-Nodes", + proto: header.IPv6ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + }, + { + name: "IPv4 All-Systems", + proto: header.IPv4ProtocolNumber, + addr: header.IPv4AllSystems, + }, } - // Should not be in the IPv6 all-nodes multicast group yet because the NIC has - // not been enabled yet. - isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) - } + // Should not be in the multicast group yet because the NIC has not been + // enabled yet. + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } - // The all-nodes multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // The multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // Leaving the group before disabling the NIC should not cause an error. + if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + }) } } @@ -3246,12 +3106,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + })}, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3418,3 +3279,399 @@ func TestStackSendBufferSizeOption(t *testing.T) { }) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID1 = 1 + ) + + defaultAddr := tcpip.AddressWithPrefix{ + Address: header.IPv4Any, + PrefixLen: 0, + } + defaultSubnet := defaultAddr.Subnet() + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedRoute stack.Route + }{ + // Broadcast to a locally attached subnet populates the broadcast MAC. + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: ipv4SubnetBcast, + RemoteLinkAddress: header.EthernetBroadcastAddress, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /31 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix31.Address, + RemoteAddress: ipv4Subnet31Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /32 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix32.Address, + RemoteAddress: ipv4Subnet32Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv6Addr.Address, + RemoteAddress: ipv6SubnetBcast, + NetProto: header.IPv6ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a remote subnet in the route table is send to the next-hop + // gateway. + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to an unknown subnet follows the default route. Note that this + // is essentially just routing an unknown destination IP, because w/o any + // subnet prefix information a subnet broadcast address is just a normal IP. + { + name: "IPv4 Broadcast to unknown subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: defaultSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) + } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { + t.Errorf("route mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} + +// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its +// associated address is removed should not cause a panic. +func TestRouteReleaseAfterAddrRemoval(t *testing.T) { + const ( + nicID = 1 + localAddr = tcpip.Address("\x01") + remoteAddr = tcpip.Address("\x02") + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) + } + // Should not panic. + defer r.Release() + + // Check that removing the same address fails. + if err := s.RemoveAddress(nicID, localAddr); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) + } +} + +func TestGetNetworkEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) + for _, test := range tests { + factories = append(factories, test.protoFactory) + } + + s := stack.New(stack.Options{ + NetworkProtocols: factories, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) + } + + if got := ep.NetworkProtocolNumber(); got != test.protoNum { + t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) + } + }) + } +} + +func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + } + + // Check that we get the right initial address and prefix length. + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } + + // Should still get the address when the NIC is diabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..35e5b1a2e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return true } - // If the packet is a TCP packet with a non-unicast source or destination - // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // If the packet is a TCP packet with a unspecified source or non-unicast + // destination address, then do nothing further and instruct the caller to do + // the same. The network layer handles address validation for specified source + // addresses. + if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -677,10 +679,10 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isSpecified(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 73dada928..698c8609e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -51,8 +51,8 @@ type testContext struct { // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { @@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) } func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { @@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { @@ -184,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) @@ -314,8 +312,8 @@ func TestBindToDeviceDistribution(t *testing.T) { t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) } var dstAddr tcpip.Address diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 7e8b84867..62ab6d92f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -39,7 +39,7 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo - stack *stack.Stack + proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -53,14 +53,14 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { return &f.TransportEndpointInfo } -func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { +func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Abort() { @@ -84,28 +84,28 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) - hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } @@ -130,11 +130,7 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } @@ -147,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) + r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return tcpip.ErrNoRoute } @@ -155,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -184,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -194,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip. } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( + if err := f.proto.stack.RegisterTransportEndpoint( a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, @@ -222,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -239,19 +234,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } -func (f *fakeTransportEndpoint) State() uint32 { +func (*fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} +func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { - return stack.IPTables{}, nil -} +func (*fakeTransportEndpoint) Resume(*stack.Stack) {} -func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} +func (*fakeTransportEndpoint) Wait() {} -func (f *fakeTransportEndpoint) Wait() {} +func (*fakeTransportEndpoint) LastError() *tcpip.Error { + return nil +} type fakeTransportGoodOption bool @@ -266,6 +261,8 @@ type fakeTransportProtocolOptions struct { // fakeTransportProtocol is a transport-layer protocol descriptor. It // aggregates the number of packets received via endpoints of this protocol. type fakeTransportProtocol struct { + stack *stack.Stack + packetCount int controlCount int opts fakeTransportProtocolOptions @@ -275,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil } -func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -291,26 +288,24 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) + case *tcpip.TCPModerateReceiveBufferOption: + f.opts.good = bool(*v) return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) + case *tcpip.TCPModerateReceiveBufferOption: + *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: return tcpip.ErrUnknownProtocolOption @@ -328,24 +323,19 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) - if !ok { - return false - } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(fakeTransHeaderLen) - return true + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok } -func fakeTransFactory() stack.TransportProtocol { - return &fakeTransportProtocol{} +func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { + return &fakeTransportProtocol{stack: s} } func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -382,9 +372,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -393,9 +383,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -404,9 +394,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -415,8 +405,8 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -459,9 +449,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -470,9 +460,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -481,9 +471,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -492,8 +482,8 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -538,54 +528,29 @@ func TestTransportSend(t *testing.T) { func TestTransportOptions(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } + v := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) + } + v = false + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) + } + if !v { + t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } func TestTransportForwarding(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() @@ -636,11 +601,11 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: req.ToVectorisedView(), - }) + })) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err != nil || aep == nil { t.Fatalf("Accept failed: %v, %v", aep, err) } @@ -655,10 +620,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.NetworkHeader[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.NetworkHeader[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2be1c107a..c42bb0991 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -43,6 +43,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Using header.IPv4AddressSize would cause an import cycle. +const ipv4AddressSize = 4 + // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // @@ -192,7 +195,7 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } -// A Clock provides the current time. +// A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. @@ -203,12 +206,45 @@ type Clock interface { // NowMonotonic returns a monotonic time value. NowMonotonic() int64 + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) } // Address is a byte slice cast as a string that represents the address of a // network node. Or, in the case of unix endpoints, it may represent a path. type Address string +// WithPrefix returns the address with a prefix that represents a point subnet. +func (a Address) WithPrefix() AddressWithPrefix { + return AddressWithPrefix{ + Address: a, + PrefixLen: len(a) * 8, + } +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -295,6 +331,29 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// IsBroadcast returns true if the address is considered a broadcast address. +func (s *Subnet) IsBroadcast(address Address) bool { + // Only IPv4 supports the notion of a broadcast address. + if len(address) != ipv4AddressSize { + return false + } + + // Normally, we would just compare address with the subnet's broadcast + // address but there is an exception where a simple comparison is not + // correct. This exception is for /31 and /32 IPv4 subnets where all + // addresses are considered valid host addresses. + // + // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that + // both addresses in a /31 subnet "MUST be interpreted as host addresses." + // + // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32 + // subnets. However, the same reasoning applies - if an exception is not + // made, then there do not exist any host addresses in a /32 subnet. RFC + // 4632 Section 3.1 also vaguely implies this interpretation by referring + // to addresses in /32 subnets as "host routes." + return s.Prefix() <= 30 && s.Broadcast() == address +} + // Equal returns true if s equals o. // // Needed to use cmp.Equal on Subnet as its fields are unexported. @@ -316,6 +375,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -488,7 +569,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *waiter.Queue, *Error) + // + // If peerAddr is not nil then it is populated with the peer address of the + // returned endpoint. + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -505,8 +589,8 @@ type Endpoint interface { // if waiter.EventIn is set, the endpoint is immediately readable. Readiness(mask waiter.EventMask) waiter.EventMask - // SetSockOpt sets a socket option. opt should be one of the *Option types. - SetSockOpt(opt interface{}) *Error + // SetSockOpt sets a socket option. + SetSockOpt(opt SettableSocketOption) *Error // SetSockOptBool sets a socket option, for simple cases where a value // has the bool type. @@ -516,9 +600,8 @@ type Endpoint interface { // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error - // GetSockOpt gets a socket option. opt should be a pointer to one of the - // *Option types. - GetSockOpt(opt interface{}) *Error + // GetSockOpt gets a socket option. + GetSockOpt(opt GettableSocketOption) *Error // GetSockOptBool gets a socket option for simple cases where a return // value has the bool type. @@ -547,6 +630,31 @@ type Endpoint interface { // SetOwner sets the task owner to the endpoint owner. SetOwner(owner PacketOwner) + + // LastError clears and returns the last error reported by the endpoint. + LastError() *Error +} + +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) } // EndpointInfo is the interface implemented by each endpoint info struct. @@ -648,6 +756,11 @@ const ( // whether an IPv6 socket is to be restricted to sending and receiving // IPv6 packets only. V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. @@ -673,6 +786,13 @@ const ( // TCP_MAXSEG option. MaxSegOption + // MTUDiscoverOption is used to set/get the path MTU discovery setting. + // + // NOTE: Setting this option to any other value than PMTUDiscoveryDont + // is not supported and will fail as such, and getting this option will + // always return PMTUDiscoveryDont. + MTUDiscoverOption + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control // the default TTL value for multicast messages. The default is 1. MulticastTTLOption @@ -714,14 +834,152 @@ const ( TCPWindowClampOption ) -// ErrorOption is used in GetSockOpt to specify that the last error reported by -// the endpoint should be cleared and returned. -type ErrorOption struct{} +const ( + // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use + // per-route settings. + PMTUDiscoveryWant int = iota + + // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable + // path MTU discovery. + PMTUDiscoveryDont + + // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do + // path MTU discovery. + PMTUDiscoveryDo + + // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF + // but ignore path MTU. + PMTUDiscoveryProbe +) + +// GettableNetworkProtocolOption is a marker interface for network protocol +// options that may be queried. +type GettableNetworkProtocolOption interface { + isGettableNetworkProtocolOption() +} + +// SettableNetworkProtocolOption is a marker interface for network protocol +// options that may be set. +type SettableNetworkProtocolOption interface { + isSettableNetworkProtocolOption() +} + +// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify +// a default TTL. +type DefaultTTLOption uint8 + +func (*DefaultTTLOption) isGettableNetworkProtocolOption() {} + +func (*DefaultTTLOption) isSettableNetworkProtocolOption() {} + +// GettableTransportProtocolOption is a marker interface for transport protocol +// options that may be queried. +type GettableTransportProtocolOption interface { + isGettableTransportProtocolOption() +} + +// SettableTransportProtocolOption is a marker interface for transport protocol +// options that may be set. +type SettableTransportProtocolOption interface { + isSettableTransportProtocolOption() +} + +// TCPSACKEnabled the SACK option for TCP. +// +// See: https://tools.ietf.org/html/rfc2018. +type TCPSACKEnabled bool + +func (*TCPSACKEnabled) isGettableTransportProtocolOption() {} + +func (*TCPSACKEnabled) isSettableTransportProtocolOption() {} + +// TCPRecovery is the loss deteoction algorithm used by TCP. +type TCPRecovery int32 + +func (*TCPRecovery) isGettableTransportProtocolOption() {} + +func (*TCPRecovery) isSettableTransportProtocolOption() {} + +const ( + // TCPRACKLossDetection indicates RACK is used for loss detection and + // recovery. + TCPRACKLossDetection TCPRecovery = 1 << iota + + // TCPRACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + TCPRACKStaticReoWnd + + // TCPRACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + TCPRACKNoDupTh +) + +// TCPDelayEnabled enables/disables Nagle's algorithm in TCP. +type TCPDelayEnabled bool + +func (*TCPDelayEnabled) isGettableTransportProtocolOption() {} + +func (*TCPDelayEnabled) isSettableTransportProtocolOption() {} + +// TCPSendBufferSizeRangeOption is the send buffer size range for TCP. +type TCPSendBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPSendBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPSendBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPReceiveBufferSizeRangeOption is the receive buffer size range for TCP. +type TCPReceiveBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPReceiveBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPReceiveBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPAvailableCongestionControlOption is the supported congestion control +// algorithms for TCP +type TCPAvailableCongestionControlOption string + +func (*TCPAvailableCongestionControlOption) isGettableTransportProtocolOption() {} + +func (*TCPAvailableCongestionControlOption) isSettableTransportProtocolOption() {} + +// TCPModerateReceiveBufferOption enables/disables receive buffer moderation +// for TCP. +type TCPModerateReceiveBufferOption bool + +func (*TCPModerateReceiveBufferOption) isGettableTransportProtocolOption() {} + +func (*TCPModerateReceiveBufferOption) isSettableTransportProtocolOption() {} + +// GettableSocketOption is a marker interface for socket options that may be +// queried. +type GettableSocketOption interface { + isGettableSocketOption() +} + +// SettableSocketOption is a marker interface for socket options that may be +// configured. +type SettableSocketOption interface { + isSettableSocketOption() +} // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID +func (*BindToDeviceOption) isGettableSocketOption() {} + +func (*BindToDeviceOption) isSettableSocketOption() {} + // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -730,68 +988,143 @@ type TCPInfoOption struct { RTTVar time.Duration } +func (*TCPInfoOption) isGettableSocketOption() {} + // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. type KeepaliveIdleOption time.Duration +func (*KeepaliveIdleOption) isGettableSocketOption() {} + +func (*KeepaliveIdleOption) isSettableSocketOption() {} + // KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration +func (*KeepaliveIntervalOption) isGettableSocketOption() {} + +func (*KeepaliveIntervalOption) isSettableSocketOption() {} + // TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user // specified timeout for a given TCP connection. // See: RFC5482 for details. type TCPUserTimeoutOption time.Duration +func (*TCPUserTimeoutOption) isGettableSocketOption() {} + +func (*TCPUserTimeoutOption) isSettableSocketOption() {} + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string -// AvailableCongestionControlOption is used to query the supported congestion -// control algorithms. -type AvailableCongestionControlOption string +func (*CongestionControlOption) isGettableSocketOption() {} + +func (*CongestionControlOption) isSettableSocketOption() {} -// buffer moderation. -type ModerateReceiveBufferOption bool +func (*CongestionControlOption) isGettableTransportProtocolOption() {} + +func (*CongestionControlOption) isSettableTransportProtocolOption() {} // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. type TCPLingerTimeoutOption time.Duration +func (*TCPLingerTimeoutOption) isGettableSocketOption() {} + +func (*TCPLingerTimeoutOption) isSettableSocketOption() {} + +func (*TCPLingerTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPLingerTimeoutOption) isSettableTransportProtocolOption() {} + // TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TIME_WAIT state // before being marked closed. type TCPTimeWaitTimeoutOption time.Duration +func (*TCPTimeWaitTimeoutOption) isGettableSocketOption() {} + +func (*TCPTimeWaitTimeoutOption) isSettableSocketOption() {} + +func (*TCPTimeWaitTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitTimeoutOption) isSettableTransportProtocolOption() {} + // TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a // accept to return a completed connection only when there is data to be // read. This usually means the listening socket will drop the final ACK // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration +func (*TCPDeferAcceptOption) isGettableSocketOption() {} + +func (*TCPDeferAcceptOption) isSettableSocketOption() {} + // TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding // default MinRTO used by the Stack. type TCPMinRTOOption time.Duration +func (*TCPMinRTOOption) isGettableSocketOption() {} + +func (*TCPMinRTOOption) isSettableSocketOption() {} + +func (*TCPMinRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMinRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding // default MaxRTO used by the Stack. type TCPMaxRTOOption time.Duration +func (*TCPMaxRTOOption) isGettableSocketOption() {} + +func (*TCPMaxRTOOption) isSettableSocketOption() {} + +func (*TCPMaxRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the // maximum number of retransmits after which we time out the connection. type TCPMaxRetriesOption uint64 +func (*TCPMaxRetriesOption) isGettableSocketOption() {} + +func (*TCPMaxRetriesOption) isSettableSocketOption() {} + +func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} + // TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify // the number of endpoints that can be in SYN-RCVD state before the stack // switches to using SYN cookies. type TCPSynRcvdCountThresholdOption uint64 +func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} + +func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} + +func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} + // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 +func (*TCPSynRetriesOption) isGettableSocketOption() {} + +func (*TCPSynRetriesOption) isSettableSocketOption() {} + +func (*TCPSynRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRetriesOption) isSettableTransportProtocolOption() {} + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -799,33 +1132,89 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to -// AddMembershipOption and RemoveMembershipOption. +func (*MulticastInterfaceOption) isGettableSocketOption() {} + +func (*MulticastInterfaceOption) isSettableSocketOption() {} + +// MembershipOption is used to identify a multicast membership on an interface. type MembershipOption struct { NIC NICID InterfaceAddr Address MulticastAddr Address } -// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast -// group identified by the given multicast address, on the interface matching -// the given interface address. +// AddMembershipOption identifies a multicast group to join on some interface. type AddMembershipOption MembershipOption -// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast -// group identified by the given multicast address, on the interface matching -// the given interface address. +func (*AddMembershipOption) isSettableSocketOption() {} + +// RemoveMembershipOption identifies a multicast group to leave on some +// interface. type RemoveMembershipOption MembershipOption +func (*RemoveMembershipOption) isSettableSocketOption() {} + // OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify -// a default TTL. -type DefaultTTLOption uint8 +func (*OutOfBandInlineOption) isGettableSocketOption() {} + +func (*OutOfBandInlineOption) isSettableSocketOption() {} + +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int + +func (*SocketDetachFilterOption) isSettableSocketOption() {} +// OriginalDestinationOption is used to get the original destination address +// and port of a redirected packet. +type OriginalDestinationOption FullAddress + +func (*OriginalDestinationOption) isGettableSocketOption() {} + +// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to +// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for +// new connections when it is safe from protocol viewpoint. +type TCPTimeWaitReuseOption uint8 + +func (*TCPTimeWaitReuseOption) isGettableSocketOption() {} + +func (*TCPTimeWaitReuseOption) isSettableSocketOption() {} + +func (*TCPTimeWaitReuseOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitReuseOption) isSettableTransportProtocolOption() {} + +const ( + // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot + // be reused for new connections. + TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota + + // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can + // be reused for new connections irrespective of the src/dest addresses. + TCPTimeWaitReuseGlobal + + // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can + // only be reused if the connection was a connection over loopback. i.e src/dest adddresses + // are loopback addresses. + TCPTimeWaitReuseLoopbackOnly +) + +// LingerOption is used by SetSockOpt/GetSockOpt to set/get the +// duration for which a socket lingers before returning from Close. // +// +stateify savable +type LingerOption struct { + Enabled bool + Timeout time.Duration +} + +func (*LingerOption) isGettableSocketOption() {} + +func (*LingerOption) isSettableSocketOption() {} + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -836,7 +1225,7 @@ type IPPacketInfo struct { // LocalAddr is the local address. LocalAddr Address - // DestinationAddr is the destination address. + // DestinationAddr is the destination address found in the IP header. DestinationAddr Address } @@ -868,7 +1257,10 @@ func (r Route) String() string { // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 -// NetworkProtocolNumber is the number of a network protocol. +// NetworkProtocolNumber is the EtherType of a network protocol in an Ethernet +// frame. +// +// See: https://www.iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. @@ -1031,6 +1423,10 @@ type ICMPv6ReceivedPacketStats struct { // Invalid is the total number of ICMPv6 packets received that the // transport layer could not parse. Invalid *StatCounter + + // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets + // dropped due to being router-specific packets. + RouterOnlyPacketsDroppedByHost *StatCounter } // ICMPStats collects ICMP-specific stats (both v4 and v6). @@ -1086,6 +1482,18 @@ type IPStats struct { // MalformedFragmentsReceived is the total number of IP Fragments that were // dropped due to the fragment failing validation checks. MalformedFragmentsReceived *StatCounter + + // IPTablesPreroutingDropped is the total number of IP packets dropped + // in the Prerouting chain. + IPTablesPreroutingDropped *StatCounter + + // IPTablesInputDropped is the total number of IP packets dropped in + // the Input chain. + IPTablesInputDropped *StatCounter + + // IPTablesOutputDropped is the total number of IP packets dropped in + // the Output chain. + IPTablesOutputDropped *StatCounter } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD new file mode 100644 index 000000000..06c7a3cd3 --- /dev/null +++ b/pkg/tcpip/tests/integration/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "integration_test", + size = "small", + srcs = [ + "loopback_test.go", + "multicast_broadcast_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go new file mode 100644 index 000000000..e8caf09ba --- /dev/null +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -0,0 +1,314 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) + +type ndpDispatcher struct{} + +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { + return false +} + +func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} + +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} + +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return true +} + +func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} + +func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {} + +func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {} + +func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {} + +func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {} + +// TestInitialLoopbackAddresses tests that the loopback interface does not +// auto-generate a link-local address when it is brought up. +func TestInitialLoopbackAddresses(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDispatcher{}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID, nicName string) string { + t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) + return "" + }, + }, + })}, + }) + + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + nicsInfo := s.NICInfo() + if nicInfo, ok := nicsInfo[nicID]; !ok { + t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo) + } else if got := len(nicInfo.ProtocolAddresses); got != 0 { + t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses) + } +} + +// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address. +func TestLoopbackAcceptAllInSubnet(t *testing.T) { + const ( + nicID = 1 + localPort = 80 + ) + + data := []byte{1, 2, 3, 4} + + ipv4ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + ipv4Bytes := []byte(ipv4Addr.Address) + ipv4Bytes[len(ipv4Bytes)-1]++ + otherIPv4Address := tcpip.Address(ipv4Bytes) + + ipv6ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + } + ipv6Bytes := []byte(ipv6Addr.Address) + ipv6Bytes[len(ipv6Bytes)-1]++ + otherIPv6Address := tcpip.Address(ipv6Bytes) + + tests := []struct { + name string + addAddress tcpip.ProtocolAddress + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 bind to wildcard and send to assigned address", + addAddress: ipv4ProtocolAddress, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 bind to wildcard and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + dstAddr: otherIPv4Address, + expectRx: true, + }, + { + name: "IPv4 bind to wildcard send to other address", + addAddress: ipv4ProtocolAddress, + dstAddr: remoteIPv4Addr, + expectRx: false, + }, + { + name: "IPv4 bind to other subnet-local address and send to assigned address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 bind and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: otherIPv4Address, + expectRx: true, + }, + { + name: "IPv4 bind to assigned address and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: ipv4Addr.Address, + dstAddr: otherIPv4Address, + expectRx: false, + }, + + { + name: "IPv6 bind and send to assigned address", + addAddress: ipv6ProtocolAddress, + bindAddr: ipv6Addr.Address, + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 bind to wildcard and send to other subnet-local address", + addAddress: ipv6ProtocolAddress, + dstAddr: otherIPv6Address, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + wq := waiter.Queue{} + rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer rep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := rep.Bind(bindAddr); err != nil { + t.Fatalf("rep.Bind(%+v): %s", bindAddr, err) + } + + sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer sep.Close() + + wopts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.dstAddr, + Port: localPort, + }, + } + n, _, err := sep.Write(tcpip.SlicePayload(data), wopts) + if err != nil { + t.Fatalf("sep.Write(_, _): %s", err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) + } + + if gotPayload, _, err := rep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("reep.Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} + +// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address +// in a loopback interface's associated subnet is bound to the permanently bound +// address. +func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { + const nicID = 1 + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + addrBytes := []byte(ipv4Addr.Address) + addrBytes[len(addrBytes)-1]++ + otherAddr := tcpip.Address(addrBytes) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err) + } + defer r.Release() + + params := stack.NetworkHeaderParams{ + Protocol: 111, + TTL: 64, + TOS: stack.DefaultTOS, + } + data := buffer.View([]byte{1, 2, 3, 4}) + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) + } + + // Removing the address should make the endpoint invalid. + if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) + } + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go new file mode 100644 index 000000000..4f2ca7f54 --- /dev/null +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -0,0 +1,556 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + defaultMTU = 1280 + ttl = 255 +) + +var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + ipv4Subnet = ipv4Addr.Subnet() + ipv4SubnetBcast = ipv4Subnet.Broadcast() + + ipv6Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("200a::1").To16()), + PrefixLen: 64, + } + ipv6Subnet = ipv6Addr.Subnet() + ipv6SubnetBcast = ipv6Subnet.Broadcast() + + // Remote addrs. + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + remoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16()) +) + +// TestPingMulticastBroadcast tests that responding to an Echo Request destined +// to a multicast or broadcast address uses a unicast source address for the +// reply. +func TestPingMulticastBroadcast(t *testing.T) { + const nicID = 1 + + rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + dstAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + dstAddr: ipv4Addr.Address, + }, + { + name: "IPv4 directed broadcast", + dstAddr: ipv4SubnetBcast, + }, + { + name: "IPv4 broadcast", + dstAddr: header.IPv4Broadcast, + }, + { + name: "IPv4 all-systems multicast", + dstAddr: header.IPv4AllSystems, + }, + { + name: "IPv6 unicast", + dstAddr: ipv6Addr.Address, + }, + { + name: "IPv6 all-nodes multicast", + dstAddr: header.IPv6AllNodesMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + }) + // We only expect a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + var rxICMP func(*channel.Endpoint, tcpip.Address) + var expectedSrc tcpip.Address + var expectedDst tcpip.Address + var protoNum tcpip.NetworkProtocolNumber + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + rxICMP = rxIPv4ICMP + expectedSrc = ipv4Addr.Address + expectedDst = remoteIPv4Addr + protoNum = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + rxICMP = rxIPv6ICMP + expectedSrc = ipv6Addr.Address + expectedDst = remoteIPv6Addr + protoNum = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + rxICMP(e, test.dstAddr) + pkt, ok := e.Read() + if !ok { + t.Fatal("expected ICMP response") + } + + if pkt.Route.LocalAddress != expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + } + if pkt.Route.RemoteAddress != expectedDst { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + } + + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if src != expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + } + if dst != expectedDst { + t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + } + }) + } + +} + +// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some +// multicast or broadcast address. +func TestIncomingMulticastAndBroadcast(t *testing.T) { + const ( + nicID = 1 + remotePort = 5555 + localPort = 80 + ) + + data := []byte{1, 2, 3, 4} + + rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 unicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 unicast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 unicast binding to wildcard", + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + + { + name: "IPv4 directed broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + { + name: "IPv4 directed broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, + }, + { + name: "IPv4 directed broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, + }, + { + name: "IPv4 broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, + }, + { + name: "IPv4 broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 all-systems multicast binding to all-systems multicast", + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to wildcard", + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, + }, + + // IPv6 has no notion of a broadcast. + { + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + var netproto tcpip.NetworkProtocolNumber + var rxUDP func(*channel.Endpoint, tcpip.Address) + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + netproto = header.IPv4ProtocolNumber + rxUDP = rxIPv4UDP + case header.IPv6AddressSize: + netproto = header.IPv6ProtocolNumber + rxUDP = rxIPv6UDP + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + } + + rxUDP(e, test.dstAddr) + if gotPayload, _, err := ep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} + +// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all +// interested endpoints. +func TestReuseAddrAndBroadcast(t *testing.T) { + const ( + nicID = 1 + localPort = 9000 + loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + ) + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + + tests := []struct { + name string + broadcastAddr tcpip.Address + }{ + { + name: "Subnet directed broadcast", + broadcastAddr: loopbackBroadcast, + }, + { + name: "IPv4 broadcast", + broadcastAddr: header.IPv4Broadcast, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x7f\x00\x00\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + // We use the empty subnet instead of just the loopback subnet so we + // also have a route to the IPv4 Broadcast address. + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // We create endpoints that bind to both the wildcard address and the + // broadcast address to make sure both of these types of "broadcast + // interested" endpoints receive broadcast packets. + wq := waiter.Queue{} + var eps []tcpip.Endpoint + for _, bindWildcard := range []bool{false, true} { + // Create multiple endpoints for each type of "broadcast interested" + // endpoint so we can test that all endpoints receive the broadcast + // packet. + for i := 0; i < 2; i++ { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + defer ep.Close() + + if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) + } + + bindAddr := tcpip.FullAddress{Port: localPort} + if bindWildcard { + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } else { + bindAddr.Addr = test.broadcastAddr + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } + + eps = append(eps, ep) + } + } + + for i, wep := range eps { + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.broadcastAddr, + Port: localPort, + }, + } + if n, _, err := wep.Write(data, writeOpts); err != nil { + t.Fatalf("eps[%d].Write(_, _): %s", i, err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + } + + for j, rep := range eps { + if gotPayload, _, err := rep.Read(nil); err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) + } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) + } + } + } + }) + } +} diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 7f172f978..606363567 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,14 +13,14 @@ // limitations under the License. // +build go1.9 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. package tcpip import ( - _ "time" // Used with go:linkname. + "time" // Used with go:linkname. _ "unsafe" // Required for go:linkname. ) @@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 { _, _, mono := now() return mono } + +// AfterFunc implements Clock.AfterFunc. +func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 59f3b391f..f1dd7c310 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -15,54 +15,54 @@ package tcpip import ( - "sync" "time" + + "gvisor.dev/gvisor/pkg/sync" ) -// cancellableTimerInstance is a specific instance of CancellableTimer. +// jobInstance is a specific instance of Job. // -// Different instances are created each time CancellableTimer is Reset so each -// timer has its own earlyReturn signal. This is to address a bug when a -// CancellableTimer is stopped and reset in quick succession resulting in a -// timer instance's earlyReturn signal being affected or seen by another timer -// instance. +// Different instances are created each time Job is scheduled so each timer has +// its own earlyReturn signal. This is to address a bug when a Job is stopped +// and reset in quick succession resulting in a timer instance's earlyReturn +// signal being affected or seen by another timer instance. // // Consider the following sceneario where timer instances share a common // earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a // lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second // (B), third (C), and fourth (D) instance of the timer firing, respectively): // T1: Obtain L -// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T1: Create a new Job w/ lock L (create instance A) // T2: instance A fires, blocked trying to obtain L. // T1: Attempt to stop instance A (set earlyReturn = true) -// T1: Reset timer (create instance B) +// T1: Schedule timer (create instance B) // T3: instance B fires, blocked trying to obtain L. // T1: Attempt to stop instance B (set earlyReturn = true) -// T1: Reset timer (create instance C) +// T1: Schedule timer (create instance C) // T4: instance C fires, blocked trying to obtain L. // T1: Attempt to stop instance C (set earlyReturn = true) -// T1: Reset timer (create instance D) +// T1: Schedule timer (create instance D) // T5: instance D fires, blocked trying to obtain L. // T1: Release L // -// Now that T1 has released L, any of the 4 timer instances can take L and check -// earlyReturn. If the timers simply check earlyReturn and then do nothing -// further, then instance D will never early return even though it was not -// requested to stop. If the timers reset earlyReturn before early returning, -// then all but one of the timers will do work when only one was expected to. -// If CancellableTimer resets earlyReturn when resetting, then all the timers +// Now that T1 has released L, any of the 4 timer instances can take L and +// check earlyReturn. If the timers simply check earlyReturn and then do +// nothing further, then instance D will never early return even though it was +// not requested to stop. If the timers reset earlyReturn before early +// returning, then all but one of the timers will do work when only one was +// expected to. If Job resets earlyReturn when resetting, then all the timers // will fire (again, when only one was expected to). // // To address the above concerns the simplest solution was to give each timer // its own earlyReturn signal. -type cancellableTimerInstance struct { - timer *time.Timer +type jobInstance struct { + timer Timer // Used to inform the timer to early return when it gets stopped while the // lock the timer tries to obtain when fired is held (T1 is a goroutine that // tries to cancel the timer and T2 is the goroutine that handles the timer // firing): - // T1: Obtain the lock, then call StopLocked() + // T1: Obtain the lock, then call Cancel() // T2: timer fires, and gets blocked on obtaining the lock // T1: Releases lock // T2: Obtains lock does unintended work @@ -73,27 +73,33 @@ type cancellableTimerInstance struct { earlyReturn *bool } -// stop stops the timer instance t from firing if it hasn't fired already. If it +// stop stops the job instance j from firing if it hasn't fired already. If it // has fired and is blocked at obtaining the lock, earlyReturn will be set to // true so that it will early return when it obtains the lock. -func (t *cancellableTimerInstance) stop() { - if t.timer != nil { - t.timer.Stop() - *t.earlyReturn = true +func (j *jobInstance) stop() { + if j.timer != nil { + j.timer.Stop() + *j.earlyReturn = true } } -// CancellableTimer is a timer that does some work and can be safely cancelled -// when it fires at the same time some "related work" is being done. +// Job represents some work that can be scheduled for execution. The work can +// be safely cancelled when it fires at the same time some "related work" is +// being done. // // The term "related work" is defined as some work that needs to be done while // holding some lock that the timer must also hold while doing some work. // -// Note, it is not safe to copy a CancellableTimer as its timer instance creates -// a closure over the address of the CancellableTimer. -type CancellableTimer struct { +// Note, it is not safe to copy a Job as its timer instance creates +// a closure over the address of the Job. +type Job struct { + _ sync.NoCopy + + // The clock used to schedule the backing timer + clock Clock + // The active instance of a cancellable timer. - instance cancellableTimerInstance + instance jobInstance // locker is the lock taken by the timer immediately after it fires and must // be held when attempting to stop the timer. @@ -110,75 +116,91 @@ type CancellableTimer struct { fn func() } -// StopLocked prevents the Timer from firing if it has not fired already. +// Cancel prevents the Job from executing if it has not executed already. // -// If the timer is blocked on obtaining the t.locker lock when StopLocked is -// called, it will early return instead of calling t.fn. +// Cancel requires appropriate locking to be in place for any resources managed +// by the Job. If the Job is blocked on obtaining the lock when Cancel is +// called, it will early return. // // Note, t will be modified. // -// t.locker MUST be locked. -func (t *CancellableTimer) StopLocked() { - t.instance.stop() +// j.locker MUST be locked. +func (j *Job) Cancel() { + j.instance.stop() // Nothing to do with the stopped instance anymore. - t.instance = cancellableTimerInstance{} + j.instance = jobInstance{} } -// Reset changes the timer to expire after duration d. +// Schedule schedules the Job for execution after duration d. This can be +// called on cancelled or completed Jobs to schedule them again. // -// Note, t will be modified. +// Schedule should be invoked only on unscheduled, cancelled, or completed +// Jobs. To be safe, callers should always call Cancel before calling Schedule. // -// Reset should only be called on stopped or expired timers. To be safe, callers -// should always call StopLocked before calling Reset. -func (t *CancellableTimer) Reset(d time.Duration) { +// Note, j will be modified. +func (j *Job) Schedule(d time.Duration) { // Create a new instance. earlyReturn := false // Capture the locker so that updating the timer does not cause a data race // when a timer fires and tries to obtain the lock (read the timer's locker). - locker := t.locker - t.instance = cancellableTimerInstance{ - timer: time.AfterFunc(d, func() { + locker := j.locker + j.instance = jobInstance{ + timer: j.clock.AfterFunc(d, func() { locker.Lock() defer locker.Unlock() if earlyReturn { // If we reach this point, it means that the timer fired while another - // goroutine called StopLocked while it had the lock. Simply return - // here and do nothing further. + // goroutine called Cancel while it had the lock. Simply return here + // and do nothing further. earlyReturn = false return } - t.fn() + j.fn() }), earlyReturn: &earlyReturn, } } -// Lock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Lock() {} - -// Unlock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Unlock() {} - -// NewCancellableTimer returns an unscheduled CancellableTimer with the given -// locker and fn. -// -// fn MUST NOT attempt to lock locker. -// -// Callers must call Reset to schedule the timer to fire. -func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer { - return &CancellableTimer{locker: locker, fn: fn} +// NewJob returns a new Job that can be used to schedule f to run in its own +// gorountine. l will be locked before calling f then unlocked after f returns. +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// message = "bar" +// mu.Unlock() +// +// // Output: bar +// +// f MUST NOT attempt to lock l. +// +// l MUST be locked prior to calling the returned job's Cancel(). +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// job.Cancel() +// mu.Unlock() +func NewJob(c Clock, l sync.Locker, f func()) *Job { + return &Job{ + clock: c, + locker: l, + fn: f, + } } diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index b4940e397..a82384c49 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -28,8 +28,8 @@ const ( longDuration = 1 * time.Second ) -func TestCancellableTimerReassignment(t *testing.T) { - var timer tcpip.CancellableTimer +func TestJobReschedule(t *testing.T) { + var clock tcpip.StdClock var wg sync.WaitGroup var lock sync.Mutex @@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - timer = *tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { wg.Done() }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) lock.Unlock() }() } wg.Wait() } -func TestCancellableTimerFire(t *testing.T) { +func TestJobExecution(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(middleDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(middleDuration) lock.Lock() - timer.StopLocked() + job.Cancel() lock.Unlock() - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { } } -func TestCancellableTimerResetFromShortDuration(t *testing.T) { +func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() // Wait for timer to fire if it wasn't correctly stopped. @@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { case <-time.After(middleDuration): } - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { } } -func TestCancellableTimerImmediatelyStop(t *testing.T) { +func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() } @@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) { } } -func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { +func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() for i := 0; i < 10; i++ { - timer.Reset(middleDuration) + job.Schedule(middleDuration) lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration * 2) - timer.StopLocked() + job.Cancel() lock.Unlock() } @@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration) - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() @@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { } } -func TestManyCancellableTimerResetUnderLock(t *testing.T) { +func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..41eb0ca44 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -74,6 +74,8 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -343,7 +345,16 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() + } return nil } @@ -411,9 +422,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() return nil default: @@ -426,9 +440,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + }) + pkt.Owner = owner - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -443,15 +461,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + pkt.Data = data.ToVectorisedView() + if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: data.ToVectorisedView(), - TransportHeader: buffer.View(icmpv4), - Owner: owner, - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -459,9 +474,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + }) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -473,15 +491,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err dataVV := data.ToVectorisedView() icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + pkt.Data = dataVV if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: dataVV, - TransportHeader: buffer.View(icmpv6), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } // checkV4MappedLocked determines the effective network protocol and converts @@ -600,7 +615,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -744,15 +759,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,12 +805,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader().View().ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() @@ -827,3 +848,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} + +// LastError implements tcpip.Endpoint.LastError. +func (*endpoint) LastError() *tcpip.Error { + return nil +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 74ef6541e..87d510f96 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -13,12 +13,7 @@ // limitations under the License. // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport -// protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and activated on the stack by passing -// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport -// protocols when calling stack.New(). Then endpoints can be created by passing -// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number -// when calling Stack.NewEndpoint(). +// protocols for use in ping. package icmp import ( @@ -42,6 +37,8 @@ const ( // protocol implements stack.TransportProtocol. type protocol struct { + stack *stack.Stack + number tcpip.TransportProtocolNumber } @@ -62,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue) + return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) + return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. @@ -104,17 +101,17 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -135,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4() stack.TransportProtocol { - return &protocol{ProtocolNumber4} +func NewProtocol4(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6() stack.TransportProtocol { - return &protocol{ProtocolNumber6} +func NewProtocol6(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber6} } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index a8f8454dd..072601d2d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -45,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -79,6 +82,13 @@ type endpoint struct { closed bool stats tcpip.TransportEndpointStats `state:"nosave"` bound bool + boundNIC tcpip.NICID + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -146,8 +156,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -172,16 +182,25 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - // TODO(b/129292371): Implement. +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -193,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes tcpip.ErrNotSupported. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrNotSupported } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return tcpip.ErrNotSupported } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with // Accept, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -229,12 +248,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound { - return tcpip.ErrAlreadyBound + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil } // Unregister endpoint with all the nics. ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { @@ -242,17 +263,18 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } ep.bound = true + ep.boundNIC = addr.NIC return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -277,8 +299,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. -func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + ep.mu.Lock() + ep.linger = *v + ep.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. @@ -330,13 +364,31 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } } +func (ep *endpoint) LastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + ep.mu.Lock() + *o = ep.linger + ep.mu.Unlock() + return nil + + default: + return tcpip.ErrNotSupported + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { +func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrNotSupported } @@ -393,48 +445,73 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(b/129292371): Return network protocol. - if len(pkt.LinkHeader) > 0 { + // TODO(gvisor.dev/issue/173): Return network protocol. + if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. - hdr := header.Ethernet(pkt.LinkHeader) + hdr := header.Ethernet(pkt.LinkHeader().View()) packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header. + var combinedVV buffer.VectorisedView + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + if pkt.PktType != tcpip.PacketOutgoing { + if pkt.LinkHeader().View().IsEmpty() { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - combinedVV := linkHeader.ToVectorisedView() - combinedVV.Append(pkt.Data) - packet.data = combinedVV } - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = ep.stack.Clock().NowNanoseconds() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() @@ -448,7 +525,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (*endpoint) State() uint32 { return 0 } diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..e2fa96d17 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() { panic(*err) } } + +// saveLastError is invoked by stateify. +func (ep *endpoint) saveLastError() string { + if ep.lastError == nil { + return "" + } + + return ep.lastError.String() +} + +// loadLastError is invoked by stateify. +func (ep *endpoint) loadLastError(s string) { + if s == "" { + return + } + + ep.lastError = tcpip.StringToError(s) +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 766c7648e..e37c00523 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -83,6 +84,8 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -108,6 +111,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. @@ -182,10 +186,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -263,7 +263,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -353,19 +353,24 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if !e.associated { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + if e.hdrIncluded { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { + }) + if err := route.WriteHeaderIncludedPacket(pkt); err != nil { return 0, nil, err } } else { - hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(payloadBytes).ToVectorisedView(), - Owner: e.owner, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = e.owner + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { return 0, nil, err } } @@ -443,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -458,7 +463,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } @@ -479,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -507,12 +512,31 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -561,9 +585,12 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() return nil default: @@ -577,6 +604,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -616,8 +649,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() @@ -667,16 +707,17 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // slice. Save/restore doesn't support overlapping slices and will fail. var combinedVV buffer.VectorisedView if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) - headers = append(headers, pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) combinedVV = headers.ToVectorisedView() } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } combinedVV.Append(pkt.Data) packet.data = combinedVV - packet.timestampNS = e.stack.NowNanoseconds() + packet.timestampNS = e.stack.Clock().NowNanoseconds() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() @@ -709,3 +750,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} + +func (*endpoint) LastError() *tcpip.Error { + return nil +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 18ff89ffc..518449602 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -40,6 +40,8 @@ go_library( "endpoint_state.go", "forwarder.go", "protocol.go", + "rack.go", + "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", @@ -49,6 +51,7 @@ go_library( "segment_heap.go", "segment_queue.go", "segment_state.go", + "segment_unsafe.go", "snd.go", "snd_state.go", "tcp_endpoint_list.go", @@ -66,6 +69,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", @@ -82,6 +86,7 @@ go_test( "dual_stack_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", + "tcp_rack_test.go", "tcp_sack_test.go", "tcp_test.go", "tcp_timestamp_test.go", @@ -89,6 +94,7 @@ go_test( shard_count = 10, deps = [ ":tcp", + "//pkg/rand", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6e00e5526..b706438bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -212,7 +212,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) - n.amss = mssForRoute(&n.route) + n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) @@ -380,6 +380,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.portFlags = e.portFlags n.boundBindToDevice = e.boundBindToDevice n.boundPortFlags = e.boundPortFlags + n.userMSS = e.userMSS } // reserveTupleLocked reserves an accepted endpoint's tuple. @@ -481,9 +482,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - // TODO(b/143300739): Use the userMSS of the listening socket - // for accepted sockets. - switch { case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -514,16 +512,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) // Send SYN without window scaling because we currently - // dont't encode this information in the cookie. + // don't encode this information in the cookie. // // Enable Timestamp option if the original syn did have // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(timeStampOffset()), + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: mssForRoute(&s.route), + MSS: calculateAdvertisedMSS(e.userMSS, s.route), } e.sendSynTCP(&s.route, tcpFields{ id: s.id, diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 81b740115..189c01c8f 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.LastError() + } } // Wait for notification. @@ -519,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled SACKEnabled + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. @@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.LastError() + } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -740,11 +746,8 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) - hdr := &pkt.Header - packetSize := pkt.Data.Size() - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) - pkt.TransportHeader = buffer.View(tcp) + tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) + pkt.TransportProtocolNumber = header.TCPProtocolNumber tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -756,8 +759,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta }) copy(tcp[header.TCPMinimumSize:], tf.opts) - length := uint16(hdr.UsedLength() + packetSize) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We @@ -795,17 +797,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso packetSize = size } size -= packetSize - var pkt stack.PacketBuffer - pkt.Header = buffer.NewPrependable(hdrSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrSize, + }) pkt.Hash = tf.txHash pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + pkt.NetworkProtocolNumber = r.NetProto data.ReadToVV(&pkt.Data, packetSize) - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) - pkts.PushBack(&pkt) + pkts.PushBack(pkt) } if tf.ttl == 0 { @@ -831,12 +834,12 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := &stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Data: data, - Hash: tf.txHash, - Owner: owner, - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, + Data: data, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { @@ -895,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -1000,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. @@ -1018,14 +1020,19 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // delivered to this endpoint from the demuxer when the endpoint // is transitioned to StateClose. func (e *endpoint) transitionToStateCloseLocked() { - if e.EndpointState() == StateClose { + s := e.EndpointState() + if s == StateClose { return } + + if s.connected() { + e.stack.Stats().TCP.CurrentConnected.Decrement() + e.stack.Stats().TCP.EstablishedClosed.Increment() + } + // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() e.setEndpointState(StateClose) - e.stack.Stats().TCP.CurrentConnected.Decrement() - e.stack.Stats().TCP.EstablishedClosed.Increment() } // tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed @@ -1128,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { } cont, err := e.handleSegment(s) + s.decRef() if err != nil { - s.decRef() return err } if !cont { - s.decRef() return nil } } @@ -1155,13 +1161,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { return nil } -// handleSegment handles a given segment and notifies the worker goroutine if -// if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { - // Invoke the tcp probe if installed. +func (e *endpoint) probeSegment() { if e.probe != nil { e.probe(e.completeState()) } +} + +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. The tcp probe function will update + // the TCPEndpointState after the segment is processed. + defer e.probeSegment() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { @@ -1208,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } + // Increase counter if after processing the segment we would potentially + // advertise a zero window. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. @@ -1220,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // or a notification from the protocolMainLoop (caller goroutine). // This means that with this return, the segment dequeue below can // never occur on a closed endpoint. - s.decRef() return false, nil } @@ -1412,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.rcv.nonZeroWindow() } - if n¬ifyReceiveWindowChanged != 0 { - e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) - } - if n¬ifyMTUChanged != 0 { e.sndBufMu.Lock() count := e.packetTooBigCount @@ -1690,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 || n¬ifyAbort != 0 { + if n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 804e95aea..560b4904c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -78,16 +78,15 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv4(t, c.GetPacket(), ackCheckers...) // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -186,16 +185,15 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv6(t, c.GetV6Packet(), ackCheckers...) // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -285,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -321,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -354,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -373,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -494,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -512,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() + var addr tcpip.FullAddress + nep, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -528,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) { } } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) } } @@ -568,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } const n = uint16(32) @@ -612,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bd3ec5a8d..ae817091a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,6 +63,17 @@ const ( StateClosing ) +const ( + // rcvAdvWndScale is used to split the available socket buffer into + // application buffer and the window to be advertised to the peer. This is + // currently hard coded to split the available space equally. + rcvAdvWndScale = 1 + + // SegOverheadFactor is used to multiply the value provided by the + // user on a SetSockOpt for setting the socket send/receive buffer sizes. + SegOverheadFactor = 2 +) + // connected returns true when s is one of the states representing an // endpoint connected to a peer. func (s EndpointState) connected() bool { @@ -149,7 +160,6 @@ func (s EndpointState) String() string { // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota - notifyReceiveWindowChanged notifyClose notifyMTUChanged notifyDrain @@ -384,13 +394,26 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + // rcvBufSize is the total size of the receive buffer. + rcvBufSize int + // rcvBufUsed is the actual number of payload bytes held in the receive buffer + // not counting any overheads of the segments itself. NOTE: This will always + // be strictly <= rcvMemUsed below. rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams + // rcvMemUsed tracks the total amount of memory in use by received segments + // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // compute the window and the actual available buffer space. This is distinct + // from rcvBufUsed above which is the actual number of payload bytes held in + // the buffer not including any segment overheads. + // + // rcvMemUsed must be accessed atomically. + rcvMemUsed int32 + // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. mu sync.Mutex `state:"nosave"` @@ -449,10 +472,11 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. - // - // recentTS must be read/written atomically. recentTS uint32 + // recentTSTime is the unix time when we updated recentTS last. + recentTSTime time.Time `state:".(unixTime)"` + // tsOffset is a randomized offset added to the value of the // TSVal field in the timestamp option. tsOffset uint32 @@ -653,6 +677,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -666,7 +693,8 @@ func (e *endpoint) UniqueID() uint64 { // r, it will be used; otherwise, the maximum possible MSS will be used. func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { // The maximum possible MSS is dependent on the route. - maxMSS := mssForRoute(&r) + // TODO(b/143359391): Respect TCP Min and Max size. + maxMSS := uint16(r.MTU() - header.TCPMinimumSize) if userMSS != 0 && userMSS < maxMSS { return userMSS @@ -795,15 +823,15 @@ func (e *endpoint) EndpointState() EndpointState { return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) } -// setRecentTimestamp atomically sets the recentTS field to the -// provided value. +// setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - atomic.StoreUint32(&e.recentTS, recentTS) + e.recentTS = recentTS + e.recentTSTime = time.Now() } -// recentTimestamp atomically reads and returns the value of the recentTS field. +// recentTimestamp returns the value of the recentTS field. func (e *endpoint) recentTimestamp() uint32 { - return atomic.LoadUint32(&e.recentTS) + return e.recentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -847,12 +875,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -862,12 +890,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.cc = cs } - var mrb tcpip.ModerateReceiveBufferOption + var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - var de DelayEnabled + var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -886,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.probe = p } - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) @@ -899,10 +927,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) switch e.EndpointState() { - case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: + case StateInitial, StateBound: + // This prevents blocking of new sockets which are not + // connected when SO_LINGER is set. + result |= waiter.EventHUp + + case StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case StateClose, StateError: + case StateClose, StateError, StateTimeWait: // Ready for anything. result = mask @@ -1005,6 +1038,26 @@ func (e *endpoint) Close() { return } + if e.linger.Enabled && e.linger.Timeout == 0 { + s := e.EndpointState() + isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv + if isResetState { + // Close the endpoint without doing full shutdown and + // send a RST. + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.closeNoShutdownLocked() + + // Wake up worker to close the endpoint. + switch s { + case StateSynRecv: + e.notifyProtocolGoroutine(notifyClose) + default: + e.notifyProtocolGoroutine(notifyTickleWorker) + } + return + } + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -1050,6 +1103,8 @@ func (e *endpoint) closeNoShutdownLocked() { e.notifyProtocolGoroutine(notifyClose) } else { e.transitionToStateCloseLocked() + // Notify that the endpoint is closed. + e.waiterQueue.Notify(waiter.EventHUp) } } @@ -1104,10 +1159,16 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// wndFromSpace returns the window that we can advertise based on the available +// receive buffer space. +func wndFromSpace(space int) int { + return space / (1 << rcvAdvWndScale) +} + // initialReceiveWindow returns the initial receive window to advertise in the // SYN/SYN-ACK. func (e *endpoint) initialReceiveWindow() int { - rcvWnd := e.receiveBufferAvailable() + rcvWnd := wndFromSpace(e.receiveBufferAvailable()) if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } @@ -1184,14 +1245,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = rcvWnd - availAfter := e.receiveBufferAvailableLocked() - mask := uint32(notifyReceiveWindowChanged) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -1209,6 +1268,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +func (e *endpoint) LastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1260,18 +1327,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { v := views[s.viewToDeliver] s.viewToDeliver++ + var delta int if s.viewToDeliver >= len(views) { e.rcvList.Remove(s) + // We only free up receive buffer space when the segment is released as the + // segment is still holding on to the views even though some views have been + // read out to the user. + delta = s.segMemSize() s.decRef() } e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1284,14 +1355,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { - // The endpoint cannot be written to if it's not connected. - if !e.EndpointState().connected() { - switch e.EndpointState() { - case StateError: - return 0, e.HardError - default: - return 0, tcpip.ErrClosedForSend - } + switch s := e.EndpointState(); { + case s == StateError: + return 0, e.HardError + case !s.connecting() && !s.connected(): + return 0, tcpip.ErrClosedForSend + case s.connecting(): + // As per RFC793, page 56, a send request arriving when in connecting + // state, can be queued to be completed after the state becomes + // connected. Return an error code for the caller of endpoint Write to + // try again, until the connection handshake is complete. + return 0, tcpip.ErrWouldBlock } // Check if the connection has already been closed for sends. @@ -1445,11 +1519,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } // windowCrossedACKThresholdLocked checks if the receive window to be announced -// now would be under aMSS or under half receive buffer, whichever smaller. This -// is useful as a receive side silly window syndrome prevention mechanism. If -// window grows to reasonable value, we should send ACK to the sender to inform -// the rx space is now large. We also want ensure a series of small read()'s -// won't trigger a flood of spurious tiny ACK's. +// would be under aMSS or under the window derived from half receive buffer, +// whichever smaller. This is useful as a receive side silly window syndrome +// prevention mechanism. If window grows to reasonable value, we should send ACK +// to the sender to inform the rx space is now large. We also want ensure a +// series of small read()'s won't trigger a flood of spurious tiny ACK's. // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of @@ -1460,17 +1534,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := e.receiveBufferAvailableLocked() + newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 } - threshold := int(e.amss) - if threshold > e.rcvBufSize/2 { - threshold = e.rcvBufSize / 2 + // rcvBufFraction is the inverse of the fraction of receive buffer size that + // is used to decide if the available buffer space is now above it. + const rcvBufFraction = 2 + if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + threshold = wndThreshold } - switch { case oldAvail < threshold && newAvail >= threshold: return true, true @@ -1589,21 +1664,34 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs ReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs tcpip.TCPReceiveBufferSizeRangeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) + } + + if v > rs.Max { + v = rs.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < rs.Min { v = rs.Min } - if v > rs.Max { - v = rs.Max - } + } else { + v = math.MaxInt32 } - mask := uint32(notifyReceiveWindowChanged) - e.LockUser() e.rcvListMu.Lock() @@ -1617,14 +1705,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { v = 1 << scale } - // Make sure 2*size doesn't overflow. - if v > math.MaxInt32/2 { - v = math.MaxInt32 / 2 - } - - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = v - availAfter := e.receiveBufferAvailableLocked() + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvAutoParams.disabled = true @@ -1632,24 +1715,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } e.rcvListMu.Unlock() e.UnlockUser() - e.notifyProtocolGoroutine(mask) case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) + } + + if v > ss.Max { + v = ss.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < ss.Min { v = ss.Min } - if v > ss.Max { - v = ss.Max - } + } else { + v = math.MaxInt32 } e.sndBufMu.Lock() @@ -1682,7 +1772,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1696,10 +1786,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case tcpip.BindToDeviceOption: - id := tcpip.NICID(v) + case *tcpip.BindToDeviceOption: + id := tcpip.NICID(*v) if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } @@ -1707,40 +1797,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = id e.UnlockUser() - case tcpip.KeepaliveIdleOption: + case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() - e.keepalive.idle = time.Duration(v) + e.keepalive.idle = time.Duration(*v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.KeepaliveIntervalOption: + case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() - e.keepalive.interval = time.Duration(v) + e.keepalive.interval = time.Duration(*v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.OutOfBandInlineOption: + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. - case tcpip.TCPUserTimeoutOption: + case *tcpip.TCPUserTimeoutOption: e.LockUser() - e.userTimeout = time.Duration(v) + e.userTimeout = time.Duration(*v) e.UnlockUser() - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and // validate that the specified algorithm is actually // supported in the stack. - var avail tcpip.AvailableCongestionControlOption + var avail tcpip.TCPAvailableCongestionControlOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { return err } availCC := strings.Split(string(avail), " ") for _, cc := range availCC { - if v == tcpip.CongestionControlOption(cc) { + if *v == tcpip.CongestionControlOption(cc) { e.LockUser() state := e.EndpointState() - e.cc = v + e.cc = *v switch state { case StateEstablished: if e.EndpointState() == state { @@ -1756,33 +1846,43 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.TCPLingerTimeoutOption: + case *tcpip.TCPLingerTimeoutOption: e.LockUser() - if v < 0 { + + switch { + case *v < 0: // Same as effectively disabling TCPLinger timeout. - v = 0 - } - var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption - if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { - // We were unable to retrieve a stack config, just use - // the DefaultTCPLingerTimeout. - if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { - stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + *v = -1 + case *v == 0: + // Same as the stack default. + var stackLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err)) } + *v = stackLingerTimeout + case *v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout): + // Cap it to Stack's default TCP_LINGER2 timeout. + *v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) + default: } - // Cap it to the stack wide TCPLinger timeout. - if v > stkTCPLingerTimeout { - v = stkTCPLingerTimeout - } - e.tcpLingerTimeout = time.Duration(v) + + e.tcpLingerTimeout = time.Duration(*v) e.UnlockUser() - case tcpip.TCPDeferAcceptOption: + case *tcpip.TCPDeferAcceptOption: e.LockUser() - if time.Duration(v) > MaxRTO { - v = tcpip.TCPDeferAcceptOption(MaxRTO) + if time.Duration(*v) > MaxRTO { + *v = tcpip.TCPDeferAcceptOption(MaxRTO) } - e.deferAccept = time.Duration(v) + e.deferAccept = time.Duration(*v) + e.UnlockUser() + + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + e.LockUser() + e.linger = *v e.UnlockUser() default: @@ -1896,6 +1996,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := header.TCPDefaultMSS return v, nil + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1938,15 +2043,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) @@ -1998,6 +2096,24 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() + case *tcpip.OriginalDestinationOption: + e.LockUser() + ipt := e.stack.IPTables() + addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + e.UnlockUser() + if err != nil { + return err + } + *o = tcpip.OriginalDestinationOption{ + Addr: addr, + Port: port, + } + + case *tcpip.LingerOption: + e.LockUser() + *o = e.linger + e.UnlockUser() + default: return tcpip.ErrUnknownProtocolOption } @@ -2125,12 +2241,66 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc h.Write(portBuf) portOffset := h.Sum32() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err)) + } + + reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal + if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { + switch netProto { + case header.IPv4ProtocolNumber: + reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + case header.IPv6ProtocolNumber: + reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + } + } + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { - return false, nil + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if err != tcpip.ErrPortInUse || !reuse { + return false, nil + } + transEPID := e.ID + transEPID.LocalPort = p + // Check if an endpoint is registered with demuxer in TIME-WAIT and if + // we can reuse it. If we can't find a transport endpoint then we just + // skip using this port as it's possible that either an endpoint has + // bound the port but not registered with demuxer yet (no listen/connect + // done yet) or the reservation was freed between the check above and + // the FindTransportEndpoint below. But rather than retry the same port + // we just skip it and move on. + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + if transEP == nil { + // ReservePort failed but there is no registered endpoint with + // demuxer. Which indicates there is at least some endpoint that has + // bound the port. + return false, nil + } + + tcpEP := transEP.(*endpoint) + tcpEP.LockUser() + // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but + // less than 1 second has elapsed since its recentTS was updated then + // we cannot reuse the port. + if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + tcpEP.UnlockUser() + return false, nil + } + // Since the endpoint is in TIME-WAIT it should be safe to acquire its + // Lock while holding the lock for this endpoint as endpoints in + // TIME-WAIT do not acquire locks on other endpoints. + tcpEP.workerCleanup = false + tcpEP.cleanupLocked() + tcpEP.notifyProtocolGoroutine(notifyAbort) + tcpEP.UnlockUser() + // Now try and Reserve again if it fails then we skip. + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + return false, nil + } } id := e.ID @@ -2368,7 +2538,9 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +// +// addr if not-nil will contain the peer address of the returned endpoint. +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2390,6 +2562,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } + if peerAddr != nil { + *peerAddr = n.getRemoteAddress() + } return n, n.waiterQueue, nil } @@ -2426,47 +2601,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) - if err != nil { - return err - } - - e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = e.portFlags - e.isPortReserved = true - e.effectiveNetProtos = netProtos - e.ID.LocalPort = port - - // Any failures beyond this point must remove the port registration. - defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { - if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) - e.isPortReserved = false - e.effectiveNetProtos = nil - e.ID.LocalPort = 0 - e.ID.LocalAddress = "" - e.boundNICID = 0 - e.boundBindToDevice = 0 - e.boundPortFlags = ports.Flags{} - } - }(e.boundPortFlags, e.boundBindToDevice) - + var nic tcpip.NICID // If an address is specified, we must ensure that it's one of our // local addresses. if len(addr.Addr) != 0 { - nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { return tcpip.ErrBadLocalAddress } - - e.boundNICID = nic e.ID.LocalAddress = addr.Addr } - if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + id := e.ID + id.LocalPort = p + // CheckRegisterTransportEndpoint should only return an error if there is a + // listening endpoint bound with the same id and portFlags and bindToDevice + // options. + // + // NOTE: Only listening and connected endpoint register with + // demuxer. Further connected endpoints always have a remote + // address/port. Hence this will only return an error if there is a matching + // listening endpoint. + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + return false + } + return true + }) + if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. + e.boundNICID = nic + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.ID.LocalPort = port + // Mark endpoint as bound. e.setEndpointState(StateBound) @@ -2494,11 +2667,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotConnected } + return e.getRemoteAddress(), nil +} + +func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort, NIC: e.boundNICID, - }, nil + } } func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -2531,6 +2708,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } @@ -2557,13 +2746,8 @@ func (e *endpoint) updateSndBufferUsage(v int) { func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { + e.rcvBufUsed += s.payloadSize() s.incRef() - e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down below MSS - // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -2578,15 +2762,17 @@ func (e *endpoint) readyToRead(s *segment) { func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. - if e.rcvBufUsed >= e.rcvBufSize { + memUsed := e.receiveMemUsed() + if memUsed >= e.rcvBufSize { return 0 } - return e.rcvBufSize - e.rcvBufUsed + return e.rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. +// receive buffer based on the actual memory used by all segments held in +// receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvListMu.Lock() available := e.receiveBufferAvailableLocked() @@ -2594,16 +2780,37 @@ func (e *endpoint) receiveBufferAvailable() int { return available } +// receiveBufferUsed returns the amount of in-use receive buffer. +func (e *endpoint) receiveBufferUsed() int { + e.rcvListMu.Lock() + used := e.rcvBufUsed + e.rcvListMu.Unlock() + return used +} + +// receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { e.rcvListMu.Lock() size := e.rcvBufSize e.rcvListMu.Unlock() - return size } +// receiveMemUsed returns the total memory in use by segments held by this +// endpoint. +func (e *endpoint) receiveMemUsed() int { + return int(atomic.LoadInt32(&e.rcvMemUsed)) +} + +// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed. +func (e *endpoint) updateReceiveMemUsed(delta int) { + atomic.AddInt32(&e.rcvMemUsed, int32(delta)) +} + +// maxReceiveBufferSize returns the stack wide maximum receive buffer size for +// an endpoint. func (e *endpoint) maxReceiveBufferSize() int { - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2650,15 +2857,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(e.tsOffset) + return tcpTimeStamp(time.Now(), e.tsOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(offset uint32) uint32 { - now := time.Now() - return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { + return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending @@ -2684,7 +2890,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled + var v tcpip.TCPSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return @@ -2753,7 +2959,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RcvAcc: e.rcv.rcvAcc, RcvWndScale: e.rcv.rcvWndScale, PendingBufUsed: e.rcv.pendingBufUsed, - PendingBufSize: e.rcv.pendingBufSize, } // Copy sender state. @@ -2801,6 +3006,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState { WEst: cubic.wEst, } } + + rc := e.snd.rc + s.Sender.RACKState = stack.TCPRACKState{ + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + } return s } @@ -2869,8 +3082,3 @@ func (e *endpoint) Wait() { <-notifyCh } } - -func mssForRoute(r *stack.Route) uint16 { - // TODO(b/143359391): Respect TCP Min and Max size. - return uint16(r.MTU() - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index abf1ac5c9..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() { // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. - e.segmentQueue.setLimit(0) + e.segmentQueue.freeze() e.mu.Lock() defer e.mu.Unlock() @@ -178,18 +178,18 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) @@ -309,6 +309,16 @@ func (e *endpoint) loadLastError(s string) { e.lastError = tcpip.StringToError(s) } +// saveRecentTSTime is invoked by stateify. +func (e *endpoint) saveRecentTSTime() unixTime { + return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} +} + +// loadRecentTSTime is invoked by stateify. +func (e *endpoint) loadRecentTSTime(unix unixTime) { + e.recentTSTime = time.Unix(unix.second, unix.nano) +} + // saveHardError is invoked by stateify. func (e *EndpointInfo) saveHardError() string { if e.HardError == nil { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b34e47bbd..5bce73605 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -12,16 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package tcp contains the implementation of the TCP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing tcp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package tcp contains the implementation of the TCP transport protocol. package tcp import ( - "fmt" "runtime" "strings" "time" @@ -30,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -62,6 +57,10 @@ const ( // FIN_WAIT_2 state before being marked closed. DefaultTCPLingerTimeout = 60 * time.Second + // MaxTCPLingerTimeout is the maximum amount of time that sockets + // linger in FIN_WAIT_2 state before being marked closed. + MaxTCPLingerTimeout = 120 * time.Second + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger // in TIME_WAIT state before being marked closed. DefaultTCPTimeWaitTimeout = 60 * time.Second @@ -76,31 +75,6 @@ const ( ccCubic = "cubic" ) -// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type DelayEnabled bool - -// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max TCP send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// TCP receive buffer sizes. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -159,16 +133,20 @@ func (s *synRcvdCounter) Threshold() uint64 { } type protocol struct { + stack *stack.Stack + mu sync.RWMutex sackEnabled bool + recovery tcpip.TCPRecovery delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.TCPSendBufferSizeRangeOption + recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool - tcpLingerTimeout time.Duration - tcpTimeWaitTimeout time.Duration + lingerTimeout time.Duration + timeWaitTimeout time.Duration + timeWaitReuse tcpip.TCPTimeWaitReuseOption minRTO time.Duration maxRTO time.Duration maxRetries uint32 @@ -183,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid tcp packet size. @@ -220,21 +198,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { - return false + return stack.UnknownDestinationPacketMalformed } - // There's nothing to do if this is already a reset packet. - if s.flagIsSet(header.TCPFlagRst) { - return true + if !s.flagIsSet(header.TCPFlagRst) { + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) } - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) - return true + return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. @@ -272,43 +249,49 @@ func replyWithReset(s *segment, tos, ttl uint8) { } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case *tcpip.TCPSACKEnabled: + p.mu.Lock() + p.sackEnabled = bool(*v) + p.mu.Unlock() + return nil + + case *tcpip.TCPRecovery: p.mu.Lock() - p.sackEnabled = bool(v) + p.recovery = *v p.mu.Unlock() return nil - case DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.Lock() - p.delayEnabled = bool(v) + p.delayEnabled = bool(*v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.sendBufferSize = v + p.sendBufferSize = *v p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.recvBufferSize = v + p.recvBufferSize = *v p.mu.Unlock() return nil - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { - if string(v) == c { + if string(*v) == c { p.mu.Lock() - p.congestionControl = string(v) + p.congestionControl = string(*v) p.mu.Unlock() return nil } @@ -317,66 +300,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // is specified. return tcpip.ErrNoSuchFile - case tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() - p.moderateReceiveBuffer = bool(v) + p.moderateReceiveBuffer = bool(*v) p.mu.Unlock() return nil - case tcpip.TCPLingerTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPLingerTimeoutOption: p.mu.Lock() - p.tcpLingerTimeout = time.Duration(v) + if *v < 0 { + p.lingerTimeout = 0 + } else { + p.lingerTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPTimeWaitTimeoutOption: p.mu.Lock() - p.tcpTimeWaitTimeout = time.Duration(v) + if *v < 0 { + p.timeWaitTimeout = 0 + } else { + p.timeWaitTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMinRTOOption: - if v < 0 { - v = tcpip.TCPMinRTOOption(MinRTO) + case *tcpip.TCPTimeWaitReuseOption: + if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { + return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.minRTO = time.Duration(v) + p.timeWaitReuse = *v p.mu.Unlock() return nil - case tcpip.TCPMaxRTOOption: - if v < 0 { - v = tcpip.TCPMaxRTOOption(MaxRTO) + case *tcpip.TCPMinRTOOption: + p.mu.Lock() + if *v < 0 { + p.minRTO = MinRTO + } else { + p.minRTO = time.Duration(*v) } + p.mu.Unlock() + return nil + + case *tcpip.TCPMaxRTOOption: p.mu.Lock() - p.maxRTO = time.Duration(v) + if *v < 0 { + p.maxRTO = MaxRTO + } else { + p.maxRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRetriesOption: + case *tcpip.TCPMaxRetriesOption: p.mu.Lock() - p.maxRetries = uint32(v) + p.maxRetries = uint32(*v) p.mu.Unlock() return nil - case tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(v)) + p.synRcvdCount.SetThreshold(uint64(*v)) p.mu.Unlock() return nil - case tcpip.TCPSynRetriesOption: - if v < 1 || v > 255 { + case *tcpip.TCPSynRetriesOption: + if *v < 1 || *v > 255 { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.synRetries = uint8(v) + p.synRetries = uint8(*v) p.mu.Unlock() return nil @@ -386,27 +382,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.TCPSACKEnabled: + p.mu.RLock() + *v = tcpip.TCPSACKEnabled(p.sackEnabled) + p.mu.RUnlock() + return nil + + case *tcpip.TCPRecovery: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.TCPRecovery(p.recovery) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.TCPDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -418,27 +420,33 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil - case *tcpip.AvailableCongestionControlOption: + case *tcpip.TCPAvailableCongestionControlOption: p.mu.RLock() - *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.RUnlock() return nil - case *tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.RLock() - *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer) p.mu.RUnlock() return nil case *tcpip.TCPLingerTimeoutOption: p.mu.RLock() - *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout) p.mu.RUnlock() return nil case *tcpip.TCPTimeWaitTimeoutOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPTimeWaitReuseOption: + p.mu.RLock() + *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) p.mu.RUnlock() return nil @@ -495,46 +503,34 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - - // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(offset) - if !ok { - panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) - } - } - - pkt.TransportHeader = hdr - pkt.Data.TrimFront(len(hdr)) - return true + return parse.TCP(pkt) } // NewProtocol returns a TCP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(s *stack.Stack) stack.TransportProtocol { p := protocol{ - sendBufferSize: SendBufferSizeOption{ + stack: s, + sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: ReceiveBufferSizeOption{ + recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, }, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, - tcpLingerTimeout: DefaultTCPLingerTimeout, - tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + lingerTimeout: DefaultTCPLingerTimeout, + timeWaitTimeout: DefaultTCPTimeWaitTimeout, + timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, + recovery: tcpip.TCPRACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go new file mode 100644 index 000000000..439932595 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// RACK is a loss detection algorithm used in TCP to detect packet loss and +// reordering using transmission timestamp of the packets instead of packet or +// sequence counts. To use RACK, SACK should be enabled on the connection. + +// rackControl stores the rack related fields. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1 +// +// +stateify savable +type rackControl struct { + // xmitTime is the latest transmission timestamp of rackControl.seg. + xmitTime time.Time `state:".(unixTime)"` + + // endSequence is the ending TCP sequence number of rackControl.seg. + endSequence seqnum.Value + + // fack is the highest selectively or cumulatively acknowledged + // sequence. + fack seqnum.Value + + // minRTT is the estimated minimum RTT of the connection. + minRTT time.Duration + + // rtt is the RTT of the most recently delivered packet on the + // connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + rtt time.Duration +} + +// Update will update the RACK related fields when an ACK has been received. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +func (rc *rackControl) Update(seg *segment, ackSeg *segment, offset uint32) { + rtt := time.Now().Sub(seg.xmitTime) + + // If the ACK is for a retransmitted packet, do not update if it is a + // spurious inference which is determined by below checks: + // 1. When Timestamping option is available, if the TSVal is less than the + // transmit time of the most recent retransmitted packet. + // 2. When RTT calculated for the packet is less than the smoothed RTT + // for the connection. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // step 2 + if seg.xmitCount > 1 { + if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + return + } + } + if rtt < rc.minRTT { + return + } + } + + rc.rtt = rtt + + // The sender can either track a simple global minimum of all RTT + // measurements from the connection, or a windowed min-filtered value + // of recent RTT measurements. This implementation keeps track of the + // simple global minimum of all RTTs for the connection. + if rtt < rc.minRTT || rc.minRTT == 0 { + rc.minRTT = rtt + } + + // Update rc.xmitTime and rc.endSequence to the transmit time and + // ending sequence number of the packet which has been acknowledged + // most recently. + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { + rc.xmitTime = seg.xmitTime + rc.endSequence = endSeq + } +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go new file mode 100644 index 000000000..c9dc7e773 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveXmitTime is invoked by stateify. +func (rc *rackControl) saveXmitTime() unixTime { + return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (rc *rackControl) loadXmitTime(unix unixTime) { + rc.xmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index dd89a292a..48bf196d8 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -47,22 +47,24 @@ type receiver struct { closed bool + // pendingRcvdSegments is bounded by the receive buffer size of the + // endpoint. pendingRcvdSegments segmentHeap - pendingBufUsed seqnum.Size - pendingBufSize seqnum.Size + // pendingBufUsed tracks the total number of bytes (including segment + // overhead) currently queued in pendingRcvdSegments. + pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, lastRcvdAckTime: time.Now(), } } @@ -85,15 +87,30 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the available buffer space. - receiveBufferAvailable := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) - if r.rcvAcc.LessThan(acc) { - r.rcvAcc = acc + avail := wndFromSpace(r.ep.receiveBufferAvailable()) + if avail == 0 { + // We have no space available to accept any data, move to zero window + // state. + r.rcvWnd = 0 + return r.rcvNxt, 0 + } + + acc := r.rcvNxt.Add(seqnum.Size(avail)) + newWnd := r.rcvNxt.Size(acc) + curWnd := r.rcvNxt.Size(r.rcvAcc) + + // Update rcvAcc only if new window is > previously advertised window. We + // should never shrink the acceptable sequence space once it has been + // advertised the peer. If we shrink the acceptable sequence space then we + // would end up dropping bytes that might already be in flight. + if newWnd > curWnd { + r.rcvAcc = r.rcvNxt.Add(newWnd) + } else { + newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. - r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + r.rcvWnd = newWnd return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } @@ -195,7 +212,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of // truncated items, thus truncated items must be set to nil to avoid // memory leaks. @@ -268,14 +287,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { - case StateCloseWait: - // If the ACK acks something not yet sent then we send an ACK. - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { - r.ep.snd.sendAck() - return true, nil - } - fallthrough - case StateClosing, StateLastAck: + case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { // Just drop the segment as we have // already received a FIN and this @@ -284,9 +296,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo return true, nil } fallthrough - case StateFinWait1: - fallthrough - case StateFinWait2: + case StateFinWait1, StateFinWait2: + // If the ACK acks something not yet sent then we send an ACK. + // + // RFC793, page 37: If the connection is in a synchronized state, + // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, + // TIME-WAIT), any unacceptable segment (out of window sequence number + // or unacceptable acknowledgment number) must elicit only an empty + // acknowledgment segment containing the current send-sequence number + // and an acknowledgment indicating the next sequence number expected + // to be received, and the connection remains in the same state. + // + // Just as on Linux, we do not apply this behavior when state is + // ESTABLISHED. + // Linux receive processing for all states except ESTABLISHED and + // TIME_WAIT is here where if the ACK check fails, we attempt to + // reply back with an ACK with correct seq/ack numbers. + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186 + // The ESTABLISHED state processing is here where if the ACK check + // fails, we ignore the packet: + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., // SHUT_RD) then any data past the rcvNxt should @@ -369,10 +403,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { - // We only store the segment if it's within our buffer - // size limit. - if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += s.logicalLen() + // We only store the segment if it's within our buffer size limit. + // + // Only use 75% of the receive buffer queue for out-of-order + // segments. This ensures that we always leave some space for the inorder + // segments to arrive allowing pending segments to be processed and + // delivered to the user. + if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvListMu.Lock() + r.pendingBufUsed += s.segMemSize() + r.ep.rcvListMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -406,7 +446,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= s.logicalLen() + r.ep.rcvListMu.Lock() + r.pendingBufUsed -= s.segMemSize() + r.ep.rcvListMu.Unlock() s.decRef() } return false, nil @@ -421,6 +463,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // Just silently drop any RST packets in TIME_WAIT. We do not support // TIME_WAIT assasination as a result we confirm w/ fix 1 as described // in https://tools.ietf.org/html/rfc1337#section-3. + // + // This behavior overrides RFC793 page 70 where we transition to CLOSED + // on receiving RST, which is also default Linux behavior. + // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337. + // + // As we do not yet support PAWS, we are being conservative in ignoring + // RSTs by default. if s.flagIsSet(header.TCPFlagRst) { return false, false } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 0280892a8..13acaf753 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "sync/atomic" "time" @@ -24,6 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// queueFlags are used to indicate which queue of an endpoint a particular segment +// belongs to. This is used to track memory accounting correctly. +type queueFlags uint8 + +const ( + recvQ queueFlags = 1 << iota + sendQ +) + // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. // segment is mostly immutable, the only field allowed to change is viewToDeliver. @@ -32,6 +42,8 @@ import ( type segment struct { segmentEntry refCnt int32 + ep *endpoint + qFlags queueFlags id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -68,7 +80,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) - s.hdr = header.TCP(pkt.TransportHeader) + s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() return s } @@ -100,6 +112,8 @@ func (s *segment) clone() *segment { rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, + ep: s.ep, + qFlags: s.qFlags, } t.data = s.data.Clone(t.views[:]) return t @@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool { return s.flags&flags == flags } +// setOwner sets the owning endpoint for this segment. Its required +// to be called to ensure memory accounting for receive/send buffer +// queues is done properly. +func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) { + switch qFlags { + case recvQ: + ep.updateReceiveMemUsed(s.segMemSize()) + case sendQ: + // no memory account for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b", qFlags)) + } + s.ep = ep + s.qFlags = qFlags +} + func (s *segment) decRef() { if atomic.AddInt32(&s.refCnt, -1) == 0 { + if s.ep != nil { + switch s.qFlags { + case recvQ: + s.ep.updateReceiveMemUsed(-s.segMemSize()) + case sendQ: + // no memory accounting for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) + } + } s.route.Release() } } @@ -138,6 +178,17 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// payloadSize is the size of s.data. +func (s *segment) payloadSize() int { + return s.data.Size() +} + +// segMemSize is the amount of memory used to hold the segment data and +// the associated metadata. +func (s *segment) segMemSize() int { + return segSize + s.data.Size() +} + // parse populates the sequence & ack numbers, flags, and window fields of the // segment from the TCP header stored in the data. It then updates the view to // skip the header. diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 48a257137..54545a1b1 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -22,16 +22,16 @@ import ( // // +stateify savable type segmentQueue struct { - mu sync.Mutex `state:"nosave"` - list segmentList `state:"wait"` - limit int - used int + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + ep *endpoint + frozen bool } // emptyLocked determines if the queue is empty. // Preconditions: q.mu must be held. func (q *segmentQueue) emptyLocked() bool { - return q.used == 0 + return q.list.Empty() } // empty determines if the queue is empty. @@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool { return r } -// setLimit updates the limit. No segments are immediately dropped in case the -// queue becomes full due to the new limit. -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} - // enqueue adds the given segment to the queue. // // Returns true when the segment is successfully added to the queue, in which @@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) { // false if the queue is full, in which case ownership is retained by the // caller. func (q *segmentQueue) enqueue(s *segment) bool { + // q.ep.receiveBufferParams() must be called without holding q.mu to + // avoid lock order inversion. + bufSz := q.ep.receiveBufferSize() + used := q.ep.receiveMemUsed() q.mu.Lock() - r := q.used < q.limit - if r { + // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue + // is currently full). + allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + + if allow { q.list.PushBack(s) - q.used++ + // Set the owner now that the endpoint owns the segment. + s.setOwner(q.ep, recvQ) } q.mu.Unlock() - return r + return allow } // dequeue removes and returns the next segment from queue, if one exists. @@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment { s := q.list.Front() if s != nil { q.list.Remove(s) - q.used-- } q.mu.Unlock() return s } + +// freeze prevents any more segments from being added to the queue. i.e all +// future segmentQueue.enqueue will return false and not add the segment to the +// queue till the queue is unfroze with a corresponding segmentQueue.thaw call. +func (q *segmentQueue) freeze() { + q.mu.Lock() + q.frozen = true + q.mu.Unlock() +} + +// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and +// allows new segments to be queued again. +func (q *segmentQueue) thaw() { + q.mu.Lock() + q.frozen = false + q.mu.Unlock() +} diff --git a/tools/nogo/data/data.go b/pkg/tcpip/transport/tcp/segment_unsafe.go index eb84d0d27..0ab7b8f56 100644 --- a/tools/nogo/data/data.go +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package data contains shared data for nogo analysis. -// -// This is used to break a dependency cycle. -package data +package tcp + +import ( + "unsafe" +) -// Objdump is the dumped binary under analysis. -var Objdump string +const ( + segSize = int(unsafe.Sizeof(segment{})) +) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 5862c32f2..4c9a86cda 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -191,6 +191,10 @@ type sender struct { // cc is the congestion control algorithm in use for this sender. cc congestionControl + + // rc has the fields needed for implementing RACK loss detection + // algorithm. + rc rackControl } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -1272,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. -func (s *sender) handleRcvdSegment(seg *segment) { +func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { s.updateRTO(time.Now().Sub(s.rttMeasureTime)) s.rttMeasureSeqNum = s.sndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && seg.parsedOptions.TS { - s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. if s.ep.sackPermitted { - for _, sb := range seg.parsedOptions.SACKBlocks { + for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: // * SACK block acks data after the ack number in the @@ -1299,27 +1303,27 @@ func (s *sender) handleRcvdSegment(seg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) - seg.hasNewSACKInfo = true + rcvdSeg.hasNewSACKInfo = true } } s.SetPipe() } // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(seg) + rtx := s.checkDuplicateAck(rcvdSeg) // Stash away the current window size. - s.sndWnd = seg.window + s.sndWnd = rcvdSeg.window - ack := seg.ackNumber + ack := rcvdSeg.ackNumber // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. - if s.zeroWindowProbing && seg.window > 0 && + if s.zeroWindowProbing && rcvdSeg.window > 0 && (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { s.disableZeroWindowProbing() } @@ -1344,10 +1348,10 @@ func (s *sender) handleRcvdSegment(seg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. - elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond s.updateRTO(elapsed) } @@ -1380,6 +1384,11 @@ func (s *sender) handleRcvdSegment(seg *segment) { s.writeNext = seg.Next() } + // Update the RACK fields if SACK is enabled. + if s.ep.sackPermitted { + s.rc.Update(seg, rcvdSeg, s.ep.tsOffset) + } + s.writeList.Remove(seg) // if SACK is enabled then Only reduce outstanding if @@ -1435,7 +1444,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo { + if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { s.sendData() } } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go new file mode 100644 index 000000000..e03f101e8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" +) + +// TestRACKUpdate tests the RACK related fields are updated when an ACK is +// received on a SACK enabled connection. +func TestRACKUpdate(t *testing.T) { + const maxPayload = 10 + const tsOptionSize = 12 + const maxTCPOptionSize = 40 + + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + var xmitTime time.Time + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.RACKState is what we expect. + if state.Sender.RACKState.XmitTime.Before(xmitTime) { + t.Fatalf("RACK transmit time failed to update when an ACK is received") + } + + gotSeq := state.Sender.RACKState.EndSequence + wantSeq := state.Sender.SndNxt + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } + + if state.Sender.RACKState.RTT == 0 { + t.Fatalf("RACK RTT failed to update when an ACK is received") + } + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + xmitTime = time.Now() + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + c.SendAck(790, bytesRead) + time.Sleep(200 * time.Millisecond) +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 99521f0c1..ef7f5719f 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) + opt := tcpip.TCPSACKEnabled(enable) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 169adb16b..5b504d0d1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -74,8 +75,8 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted) + if err := ep.LastError(); err != tcpip.ErrAborted { + t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) } // Call Connect again to retreive the handshake failure status @@ -146,6 +147,24 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { } } +func TestCloseWithoutConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + c.EP.Close() + + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + func TestTCPSegmentsSentIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -222,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) { } } +// TestTCPResetsSentNoICMP confirms that we don't get an ICMP +// DstUnreachable packet when we try send a packet which is not part +// of an active session. +func TestTCPResetsSentNoICMP(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + + // Send a SYN request for a closed port. This should elicit an RST + // but NOT an ICMPv4 DstUnreachable packet. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive whatever comes back. + b := c.GetPacket() + ipHdr := header.IPv4(b) + if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { + t.Errorf("unexpected protocol, got = %d, want = %d", got, want) + } + + // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. + sent := stats.ICMP.V4PacketsSent + if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { + t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) + } +} + // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates // a RST if an ACK is received on the listening socket for which there is no // active handshake in progress and we are not using SYN cookies. @@ -273,12 +324,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -291,16 +342,16 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Lower stackwide TIME_WAIT timeout so that the reservations // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) } c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -330,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -414,8 +465,9 @@ func TestConnectResetAfterClose(t *testing.T) { // Set TCPLinger to 3 seconds so that sockets are marked closed // after 3 second in FIN_WAIT2 state. tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) + opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -428,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -470,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence number // of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -488,8 +540,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -509,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -545,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -592,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -613,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -673,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -725,135 +778,234 @@ func TestSimpleReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) } -// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when -// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when +// creating a new active TCP socket. It should be present in the sent TCP // SYN segment. -func TestUserSuppliedMSSOnConnectV4(t *testing.T) { +func TestUserSuppliedMSSOnConnect(t *testing.T) { const mtu = 5000 - const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS int - expMSS uint16 + + ips := []struct { + name string + createEP func(*context.Context) + connectAddr tcpip.Address + checker func(*testing.T, *context.Context, uint16, int) + maxMSS uint16 }{ { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + connectAddr: context.TestAddr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(true) + }, + connectAddr: context.TestV6Addr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - c.Create(-1) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - // Start connection attempt to IPv4 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) - } + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - // Receive SYN packet with our user supplied MSS. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} + if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } + + // Receive SYN packet with our user supplied MSS. + ip.checker(t, c, test.expMSS, ws) + }) + } }) } } -// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when -// creating a new active IPv6 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV6(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 +// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used +// when completing the handshake for a new TCP connection from a TCP +// listening socket. It should be present in the sent TCP SYN-ACK segment. +func TestUserSuppliedMSSOnListenAccept(t *testing.T) { + const ( + nonSynCookieAccepts = 2 + totalAccepts = 4 + mtu = 5000 + ) + + ips := []struct { + name string + createEP func(*context.Context) + sendPkt func(*context.Context, *context.Headers) + checker func(*testing.T, *context.Context, uint16, uint16) + maxMSS uint16 }{ { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendPacket(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(false) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendV6Packet(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - c.CreateV6Endpoint(true) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the SynRcvd threshold to force a syn cookie based accept to happen. + opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + } - // Start connection attempt to IPv6 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) - } + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - // Receive SYN packet with our user supplied MSS. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + bindAddr := tcpip.FullAddress{Port: context.StackPort} + if err := c.EP.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s:", bindAddr, err) + } + + if err := c.EP.Listen(totalAccepts); err != nil { + t.Fatalf("Listen(%d): %s:", totalAccepts, err) + } + + // The first nonSynCookieAccepts packets sent will trigger a gorooutine + // based accept. The rest will trigger a cookie based accept. + for i := 0; i < totalAccepts; i++ { + // Send a SYN requests. + iss := seqnum.Value(i) + srcPort := context.TestPort + uint16(i) + ip.sendPkt(c, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + ip.checker(t, c, srcPort, test.expMSS) + } + }) + } }) } } - func TestSendRstOnListenerRxSynAckV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -879,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxSynAckV6(t *testing.T) { @@ -907,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, @@ -944,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, @@ -982,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestSendRstOnListenerRxAckV4(t *testing.T) { @@ -1011,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxAckV6(t *testing.T) { @@ -1039,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestListenShutdown tests for the listening endpoint replying with RST @@ -1155,8 +1307,8 @@ func TestTOSV4(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1204,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1232,7 +1384,9 @@ func TestConnectBindToDevice(t *testing.T) { c.Create(-1) bindToDevice := tcpip.BindToDeviceOption(test.device) - c.EP.SetSockOpt(bindToDevice) + if err := c.EP.SetSockOpt(&bindToDevice); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) @@ -1276,68 +1430,91 @@ func TestConnectBindToDevice(t *testing.T) { } } -func TestRstOnSynSent(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() +func TestSynSent(t *testing.T) { + for _, test := range []struct { + name string + reset bool + }{ + {"RstOnSynSent", true}, + {"CloseOnSynSent", false}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() - // Create an endpoint, don't handshake because we want to interfere with the - // handshake process. - c.Create(-1) + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) - // Start connection attempt. - waitEntry, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) - } + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + } - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) - // Ensure that we've reached SynSent state - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - // Send a packet with a proper ACK and a RST flag to cause the socket - // to Error and close out - iss := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagRst | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) + if test.reset { + // Send a packet with a proper ACK and a RST flag to cause the socket + // to error and close out. + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + } else { + c.EP.Close() + } - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatal("timed out waiting for packet to arrive") - } + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) - } + if test.reset { + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) + } + } else { + if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted) + } + } - // Due to the RST the endpoint should be in an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + }) } } @@ -1370,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1421,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1432,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) + rcvBufSz := math.MaxUint16 + c.CreateConnected(789, 30000, rcvBufSz) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) @@ -1454,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1475,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1495,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(793), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1537,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1552,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1606,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1620,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { @@ -1639,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence // number of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), + checker.TCPSeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1685,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, 10) + const rcvBufSz = 10 + c.CreateConnected(789, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1696,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("Read failed: %s", err) } - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies + // the provided buffer value by tcp.SegOverheadFactor to calculate the actual + // receive buffer size. + data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) + for i := range data { + data[i] = byte(i % 255) + } c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -1718,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) @@ -1744,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), + checker.TCPWindow(10), ), ) } @@ -1756,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + // Start off with a certain receive buffer then cut it in half and verify that + // the right edge of the window does not shrink. + // NOTE: Netstack doubles the value specified here. + rcvBufSize := 65536 + iss := seqnum.Value(789) + // Enable window scaling with a scale of zero from our end. + c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1770,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ + // Send a 1 byte payload so that we can record the current receive window. + // Send a payload of half the size of rcvBufSize. + seqNum := iss.Add(1) + payload := []byte{1} + c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1789,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("Timed out waiting for data to arrive") } - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), + // Read the 1 byte payload we just sent. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if got, want := payload, v; !bytes.Equal(got, want) { + t.Fatalf("got data: %v, want: %v", got, want) + } + + seqNum = seqNum.Add(1) + // Verify that the ACK does not shrink the window. + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), ), ) + // Stash the initial window. + initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + // Now shrink the receive buffer to half its original size. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) + } - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ + data := generateRandomPayload(t, rcvBufSize) + // Send a payload of half the size of rcvBufSize. + c.SendPacket(data[:rcvBufSize/2], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") + // Verify that the ACK does not shrink the window. + pkt = c.GetPacket() + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { + t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } + // Send another payload of half the size of rcvBufSize. This should fill up the + // socket receive buffer and we should see a zero window. + c.SendPacket(data[rcvBufSize/2:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + // Receive data and check it. - read := make([]byte, 0, 10) + read := make([]byte, 0, rcvBufSize) for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { @@ -1842,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data = %v, want = %v", read, data) } - // Check that we get an ACK for the newly non-zero window, which is the - // new size. + // Check that we get an ACK for the newly non-zero window, which is the new + // receive buffer size we set after the connection was established. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), + checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), ) } @@ -1875,8 +2109,8 @@ func TestSimpleSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -1917,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -1939,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -1979,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2018,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2054,19 +2288,20 @@ func TestScaledWindowAccept(t *testing.T) { } // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2084,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2135,12 +2370,12 @@ func TestNonScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2165,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2180,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ + // Set the buffer size such that a window scale of 5 will be used. + const bufSz = 65535 * 10 + const ws = uint32(5) + c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) // Write chunks of 50000 bytes. - remain := wnd + remain := 0 sent := 0 data := make([]byte, 50000) - for remain > len(data) { + // Keep writing till the window drops below len(data). + for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -2201,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Don't reduce window to zero here. + if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { + remain = wnd << ws + break + } } // Make the window non-zero, but the scaled window zero. - if remain >= 16 { + for remain >= 16 { data = data[:remain-15] c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, @@ -2226,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Since the receive buffer is split between window advertisement and + // application data buffer the window does not always reflect the space + // available and actual space available can be a bit more than what is + // advertised in the window. + wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) + if wnd == 0 { + break + } + remain = wnd << ws } - // Read at least 1MSS of data. An ack should be sent in response to that. + // Read at least 2MSS of data. An ack should be sent in response to that. + // Since buffer space is now split in half between window and application + // data we need to read more than 1 MSS(65536) of data for a non-zero window + // update to be sent. For 1MSS worth of window to be available we need to + // read at least 128KB. Since our segments above were 50KB each it means + // we need to read at 3 packets. sz := 0 - for sz < defaultMTU { + for sz < defaultMTU*2 { v, _, err := c.EP.Read(nil) if err != nil { t.Fatalf("Read failed: %s", err) @@ -2253,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(sz>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2322,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize+1), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+uint32(i)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2345,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+11), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+11), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2393,8 +2646,8 @@ func TestDelay(t *testing.T) { checker.PayloadLen(len(want)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2440,8 +2693,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2463,8 +2716,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2525,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2577,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2698,12 +2951,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2725,8 +2978,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(0) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -2753,12 +3007,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2819,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - const wndScale = 2 + const wndScale = 3 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } @@ -2854,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), + checker.TCPSeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -2875,16 +3129,16 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) // Wait for connection to be established. select { case <-ch: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + if err := c.EP.LastError(); err != nil { + t.Fatalf("Connect failed: %s", err) } case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -3004,8 +3258,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { defer c.Cleanup() const numRetries = 2 - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { - t.Fatalf("could not set protocol option MaxRetries.\n") + opt := tcpip.TCPMaxRetriesOption(numRetries) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3064,8 +3319,9 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + opt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3095,6 +3351,63 @@ func TestMaxRTO(t *testing.T) { } } +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. + if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3110,8 +3423,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3131,8 +3444,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3153,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3164,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3185,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3209,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3234,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3256,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3284,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3303,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3323,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3344,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3368,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3393,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3409,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3430,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3455,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3476,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3491,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3507,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3599,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.PayloadLen((1<<scale)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3738,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3765,8 +4078,9 @@ func TestReadAfterClosedState(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -3788,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3813,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(791+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3986,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func TestDefaultBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4005,11 +4319,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4022,11 +4340,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4041,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) { func TestMinMaxBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4053,22 +4375,28 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - // Set values below the min. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { + // Set values below the min/2. + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } @@ -4079,19 +4407,21 @@ func TestMinMaxBufferSizes(t *testing.T) { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) + // Values above max are capped at max and then doubled. + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) + // Values above max are capped at max and then doubled. + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -4124,16 +4454,15 @@ func TestBindToDeviceOption(t *testing.T) { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr) + if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } bindToDevice := tcpip.BindToDeviceOption(88888) if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %s, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) + t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) + } else if bindToDevice != testAction.getBindToDevice { + t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) } @@ -4141,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) { func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() @@ -4214,7 +4543,7 @@ func TestSelfConnect(t *testing.T) { } <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { + if err := ep.LastError(); err != nil { t.Fatalf("Connect failed: %s", err) } @@ -4428,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) { checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), + checker.TCPSeqNum(seqNum), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4520,8 +4849,8 @@ func TestStackSetCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption @@ -4553,12 +4882,12 @@ func TestStackAvailableCongestionControl(t *testing.T) { s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption + var aCC tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) } } @@ -4569,18 +4898,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") + aCC := tcpip.TCPAvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption + var cc tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) } } @@ -4609,20 +4938,20 @@ func TestEndpointSetCongestionControl(t *testing.T) { var oldCC tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err) + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) } if connected { c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) } - if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err) + if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { + t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err) + t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) } got, want := cc, oldCC @@ -4634,7 +4963,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { want = tc.cc } if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) + t.Fatalf("got congestion control = %+v, want = %+v", got, want) } }) } @@ -4644,8 +4973,8 @@ func TestEndpointSetCongestionControl(t *testing.T) { func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -4655,11 +4984,23 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) + } + keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) + } c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { + t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) + } // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -4668,8 +5009,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4702,8 +5043,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4714,8 +5055,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) @@ -4740,8 +5081,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(next-1)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4767,8 +5108,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -4808,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -4852,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -4925,12 +5266,12 @@ func TestListenBacklogFull(t *testing.T) { defer c.WQ.EventUnregister(&we) for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -4942,7 +5283,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -4954,12 +5295,12 @@ func TestListenBacklogFull(t *testing.T) { // Now a new handshake must succeed. executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -4984,6 +5325,8 @@ func TestListenBacklogFull(t *testing.T) { func TestListenNoAcceptNonUnicastV4(t *testing.T) { multicastAddr := tcpip.Address("\xe0\x00\x01\x02") otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() tests := []struct { name string @@ -4991,53 +5334,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { dstAddr tcpip.Address }{ { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, }, { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, }, { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, }, { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, }, { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, }, { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, }, { - "DestOurMulticast", - context.TestAddr, - multicastAddr, + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, }, { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, }, } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5078,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5086,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") tests := []struct { name string @@ -5137,11 +5486,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5182,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5230,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5266,12 +5611,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5295,8 +5640,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(1) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create TCP endpoint. @@ -5342,12 +5688,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5358,7 +5704,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5407,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5428,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.AckNum(uint32(irs) + 1), - checker.SeqNum(uint32(iss + 1)), + checker.TCPAckNum(uint32(irs) + 1), + checker.TCPSeqNum(uint32(iss + 1)), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5447,7 +5793,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err != nil && err != tcpip.ErrWouldBlock { t.Fatalf("Accept failed: %s", err) @@ -5462,7 +5808,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5520,12 +5866,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5590,12 +5936,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5643,12 +5989,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - aep, _, err = ep.Accept() + aep, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5696,13 +6042,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -5721,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { time.Sleep(latency) rawEP.SendPacketWithTS([]byte{1}, tsVal) - // Verify that the ACK has the expected window. - wantRcvWnd := receiveBufferSize - wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 + payloadSize := receiveBufferSize * 2 + b := make([]byte, int(payloadSize)) + worker := (c.EP).(interface { StopWork() ResumeWork() @@ -5739,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Stop the worker goroutine. worker.StopWork() - start := offset - end := offset + payloadSize + start := 0 + end := payloadSize / 2 packetsSent := 0 for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + packetEnd := start + mss + if start+mss > end { + packetEnd = end + } + rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) packetsSent++ } @@ -5751,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // are waiting to be read. worker.ResumeWork() - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } + // Since we sent almost the full receive buffer worth of data (some may have + // been dropped due to segment overheads), we should get a zero window back. + pkt = c.GetPacket() + tcpHdr := header.TCP(header.IPv4(pkt).Payload()) + gotRcvWnd := tcpHdr.WindowSize() + wantAckNum := tcpHdr.AckNumber() + if got, want := int(gotRcvWnd), 0; got != want { + t.Fatalf("got rcvWnd: %d, want: %d", got, want) } - rawEP.VerifyACKRcvWnd(0) time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. + // Verify that sending more data when receiveBuffer is exhausted. rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { @@ -5786,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Verify that we receive a non-zero window update ACK. When running // under thread santizer this test can end up sending more than 1 // ack, 1 for the non-zero window - p = c.GetPacket() + p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), + checker.TCPAckNum(uint32(wantAckNum)), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { return } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) + // We use 10% here as the error margin upwards as the initial window we + // got was afer 1 segment was already in the receive buffer queue. + tolerance := 1.1 + if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { + t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) } }, )) } -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. +// This test verifies that the advertised window is auto-tuned up as the +// application is reading the data that is being received. func TestReceiveBufferAutoTuning(t *testing.T) { const mtu = 1500 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize @@ -5812,26 +6160,33 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize + tsVal := uint32(rawEP.TSVal) + rawEP.NextSeqNum-- + rawEP.SendPacketWithTS(nil, tsVal) + rawEP.NextSeqNum++ + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { return uint16(rcvWnd >> uint16(c.WindowScale)) } @@ -5848,14 +6203,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) { StopWork() ResumeWork() }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false latency := 1 * time.Millisecond - for i := 0; !done; i++ { + for i := 0; i < 5; i++ { tsVal++ // Stop the worker goroutine. @@ -5877,15 +6226,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Give 1ms for the worker to process the packets. time.Sleep(1 * time.Millisecond) - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break } + lastACK = pkt + } + if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { + t.Fatalf("advertised window got: %d, want <= %d", got, want) } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) // Now read all the data from the endpoint and invoke the // moderation API to allow for receive buffer auto-tuning @@ -5910,35 +6264,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ - if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied + rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true + pkt := c.GetPacket() + curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // If thew new current window is close maxReceiveBufferSize then terminate + // the loop. This can happen before all iterations are done due to timing + // differences when running the test. + if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { + break } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied // Increase the latency after first two iterations to // establish a low RTT value in the receiver since it // only tracks the lowest value. This ensures that when @@ -5951,6 +6290,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) { offset += payloadSize payloadSize *= 2 } + // Check that at the end of our iterations the receive window grew close to the maximum + // permissible size of maxReceiveBufferSize/2 + if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { + t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) + } + } func TestDelayEnabled(t *testing.T) { @@ -5959,7 +6304,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcp.DelayEnabled + delayEnabled tcpip.TCPDelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -5967,17 +6312,17 @@ func TestDelayEnabled(t *testing.T) { } { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) } checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcp.DelayEnabled + var gotDelayEnabled tcpip.TCPDelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } @@ -6009,24 +6354,27 @@ func TestTCPLingerTimeout(t *testing.T) { tcpLingerTimeout time.Duration want time.Duration }{ - {"NegativeLingerTimeout", -123123, 0}, - {"ZeroLingerTimeout", 0, 0}, + {"NegativeLingerTimeout", -123123, -1}, + // Zero is treated same as the stack's default TCP_LINGER2 timeout. + {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, // Values > stack's TCPLingerTimeout are capped to the stack's // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { - t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) } - var v tcpip.TCPLingerTimeoutOption + + v = 0 if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + t.Fatalf("GetSockOpt(&%T) = %s", v, err) } if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + t.Fatalf("got linger timeout = %s, want = %s", got, want) } }) } @@ -6080,12 +6428,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6099,8 +6447,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6117,8 +6465,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now send a RST and this should be ignored and not @@ -6146,8 +6494,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6199,12 +6547,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6218,8 +6566,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6236,8 +6584,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Out of order ACK should generate an immediate ACK in @@ -6253,8 +6601,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6306,12 +6654,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6325,8 +6673,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6343,8 +6691,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Send a SYN request w/ sequence number lower than @@ -6389,12 +6737,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.SendPacket(nil, ackHeaders) // Try to accept the connection. - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6412,8 +6760,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 @@ -6462,12 +6811,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6481,8 +6830,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6499,8 +6848,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) time.Sleep(2 * time.Second) @@ -6514,8 +6863,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Sleep for 4 seconds so at this point we are 1 second past the @@ -6543,8 +6892,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { @@ -6562,8 +6911,9 @@ func TestTCPCloseWithData(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } wq := &waiter.Queue{} @@ -6611,12 +6961,12 @@ func TestTCPCloseWithData(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6642,8 +6992,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now write a few bytes and then close the endpoint. @@ -6661,8 +7011,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -6676,8 +7026,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), + checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.TCPAckNum(uint32(iss+2)), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) // First send a partial ACK. @@ -6722,8 +7072,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -6743,7 +7093,10 @@ func TestTCPUserTimeout(t *testing.T) { // expired. initRTO := 1 * time.Second userTimeout := initRTO / 2 - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + v := tcpip.TCPUserTimeoutOption(userTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) + } // Send some data and wait before ACKing it. view := buffer.NewView(3) @@ -6756,8 +7109,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -6791,8 +7144,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -6817,18 +7170,31 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) - c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) + } + keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) + } + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { + t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) + } // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent // the second one should cause the connection to be // closed due to userTimeout being hit. - userTimeout := 1 * keepAliveInterval - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&userTimeout); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) + } // Check that the connection is still alive. if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { @@ -6840,8 +7206,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -6866,8 +7232,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -6883,9 +7249,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } } -func TestIncreaseWindowOnReceive(t *testing.T) { +func TestIncreaseWindowOnRead(t *testing.T) { // This test ensures that the endpoint sends an ack, - // after recv() when the window grows to more than 1 MSS. + // after read() when the window grows by more than 1 MSS. c := context.New(t, defaultMTU) defer c.Cleanup() @@ -6894,10 +7260,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -6910,46 +7275,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) { }) sent += len(data) remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Break once the window drops below defaultMTU/2 + if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { + break + } } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // We now have < 1 MSS in the buffer space. Read the data! An - // ack should be sent in response to that. The window was not - // zero, but it grew to larger than MSS. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) - } - - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) + // We now have < 1 MSS in the buffer space. Read at least > 2 MSS + // worth of data as receive buffer space + read := 0 + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + for read < defaultMTU*2 { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + read += len(v) } - // After reading two packets, we surely crossed MSS. See the ack: + // After reading > MSS worth of data, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -6966,10 +7328,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -6983,38 +7344,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { sent += len(data) remain -= len(data) - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), ) } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - // After reading two packets, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7035,14 +7387,15 @@ func TestTCPDeferAccept(t *testing.T) { } const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7058,14 +7411,14 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7073,8 +7426,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestTCPDeferAcceptTimeout(t *testing.T) { @@ -7092,14 +7445,15 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { } const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7110,7 +7464,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) // Send data. This should result in an acceptable endpoint. c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ @@ -7126,14 +7480,14 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7142,8 +7496,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestResetDuringClose(t *testing.T) { @@ -7168,8 +7522,8 @@ func TestResetDuringClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(irs.Add(1))), - checker.AckNum(uint32(iss.Add(5))))) + checker.TCPSeqNum(uint32(irs.Add(1))), + checker.TCPAckNum(uint32(iss.Add(5))))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7199,3 +7553,65 @@ func TestResetDuringClose(t *testing.T) { wg.Wait() } + +func TestStackTimeWaitReuse(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + s := c.Stack() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) + } + if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } +} + +func TestSetStackTimeWaitReuse(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + s := c.Stack() + testCases := []struct { + v int + err *tcpip.Error + }{ + {int(tcpip.TCPTimeWaitReuseDisabled), nil}, + {int(tcpip.TCPTimeWaitReuseGlobal), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + } + + for _, tc := range testCases { + opt := tcpip.TCPTimeWaitReuseOption(tc.v) + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) + if got, want := err, tc.err; got != want { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) + } + if tc.err != nil { + continue + } + + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) + } + + if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } + } +} + +// generateRandomPayload generates a random byte slice of the specified length +// causing a fatal test failure if it is unable to do so. +func generateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 8edbff964..0f9ed06cd 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -158,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.PayloadLen(len(data)+header.TCPMinimumSize+12), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(true, 0, tsVal+1), ), @@ -180,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) @@ -192,8 +194,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -217,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(false, 0, 0), ), @@ -235,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of + // that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 06fde2a79..faf51ef95 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -53,11 +53,11 @@ const ( TestPort = 4096 // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" // TestV6Addr is the source address for packets sent to the stack via // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // StackV4MappedAddr is StackAddr as a mapped v6 address. StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr @@ -73,6 +73,18 @@ const ( testInitialSequenceNumber = 789 ) +// StackAddrWithPrefix is StackAddr with its associated prefix length. +var StackAddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackAddr, + PrefixLen: 24, +} + +// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. +var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackV6Addr, + PrefixLen: header.IIDOffsetInIPv6Address * 8, +} + // Headers is used to represent the TCP header fields when building a // new packet. type Headers struct { @@ -133,30 +145,39 @@ type Context struct { // WindowScale is the expected window scale in SYN packets sent by // the stack. WindowScale uint8 + + // RcvdWindowScale is the actual window scale sent by the stack in + // SYN/SYN-ACK. + RcvdWindowScale uint8 } // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) + const sendBufferSize = 1 << 20 // 1 MiB + const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) } // Increase minimum RTO in tests to avoid test flakes due to early // retransmit in case the test executors are overloaded and cause timers // to fire earlier than expected. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { - t.Fatalf("failed to set stack-wide minRTO: %s", err) + minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) } // Some of the congestion control tests send up to 640 packets, we so @@ -179,12 +200,20 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -202,7 +231,7 @@ func New(t *testing.T, mtu uint32) *Context { t: t, s: s, linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), + WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), } } @@ -236,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) { c.CheckNoPacketTimeout(errMsg, 1*time.Second) } -// GetPacket reads a packet from the link layer endpoint and verifies +// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies // that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { +// addresses. If no packet is received in the specified timeout it will return +// nil. +func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { - c.t.Fatalf("Packet wasn't written out") return nil } @@ -255,8 +283,16 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -266,6 +302,21 @@ func (c *Context) GetPacket() []byte { return b } +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. +func (c *Context) GetPacket() []byte { + c.t.Helper() + + p := c.GetPacketWithTimeout(5 * time.Second) + if p == nil { + c.t.Fatalf("Packet wasn't written out") + return nil + } + + return p +} + // GetPacketNonBlocking reads a packet from the link layer endpoint // and verifies that it is an IPv4 packet with the expected source // and destination address. If no packet is available it will return @@ -282,15 +333,23 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { +func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { @@ -316,9 +375,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -372,26 +432,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: s, }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegment(payload, h), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacketWithAddrs builds and sends a TCP segment(with the provided payload // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendAck sends an ACK packet. @@ -441,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.PayloadLen(size+header.TCPMinimumSize+optlen), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -468,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -512,9 +575,8 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -564,9 +626,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) } // CreateConnected creates a connected TCP endpoint. @@ -607,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) } tcpHdr := header.TCP(header.IPv4(b).Payload()) + synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ @@ -624,15 +688,15 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) checker.TCP( checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) // Wait for connection to be established. select { case <-notifyCh: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { + if err := c.EP.LastError(); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -642,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } + c.RcvdWindowScale = uint8(synOpts.WS) c.Port = tcpHdr.SourcePort() } @@ -713,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) } -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { +// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches +// the provided tsVal as well as returns the original packet. +func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { + r.C.t.Helper() // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPTimestampChecker(true, 0, tsVal), ), ) @@ -731,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) opts := tcpSeg.ParsedOptions() r.RecentTS = opts.TSVal + return ackPacket +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + r.C.t.Helper() + _ = r.VerifyAndReturnACKWithTS(tsVal) } // VerifyACKRcvWnd verifies that the window advertised by the incoming ACK // matches the provided rcvWnd. func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { + r.C.t.Helper() ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), + checker.TCPWindow(rcvWnd), ), ) } @@ -762,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPSACKBlockChecker(sackBlocks), ), ) @@ -855,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * tcpCheckers := []checker.TransportChecker{ checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), + checker.TCPSeqNum(uint32(c.IRS) + 1), + checker.TCPAckNum(uint32(iss) + 1), } // Verify that tsEcr of ACK packet is wantOptions.TSVal if the @@ -876,8 +951,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Wait for connection to be established. select { case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -892,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Mark in context that timestamp option is enabled for this endpoint. c.TimeStampEnabled = true - + c.RcvdWindowScale = uint8(synOptions.WS) return &RawEndpoint{ C: c, SrcPort: tcpSeg.DestinationPort(), @@ -943,12 +1017,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { c.t.Fatalf("Accept failed: %v", err) } @@ -985,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 offset += header.EncodeMSSOption(uint32(maxPayload), opts) @@ -1023,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // are present. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) + rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) c.IRS = seqnum.Value(tcp.SequenceNumber()) tcpCheckers := []checker.TransportChecker{ checker.SrcPort(StackPort), checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), + checker.TCPAckNum(uint32(iss) + 1), checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), } @@ -1072,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // Send ACK. c.SendPacket(nil, ackHeaders) + c.RcvdWindowScale = uint8(rcvdSynOptions.WS) c.Port = StackPort return &RawEndpoint{ @@ -1091,7 +1168,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled + var v tcpip.TCPSACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index b5d2d0ba6..c78549424 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cae29fbff..d57ed5d79 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -139,7 +139,7 @@ type endpoint struct { // multicastMemberships that need to be remvoed when the endpoint is // closed. Protected by the mu mutex. - multicastMemberships []multicastMembership + multicastMemberships map[multicastMembership]struct{} // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -154,6 +154,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // +stateify savable @@ -182,12 +185,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // TTL=1. // // Linux defaults to TTL=1. - multicastTTL: 1, - multicastLoop: true, - rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, - state: StateInitial, - uniqueID: s.UniqueID(), + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + multicastMemberships: make(map[multicastMembership]struct{}), + state: StateInitial, + uniqueID: s.UniqueID(), } // Override with stack defaults. @@ -209,7 +213,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) takeLastError() *tcpip.Error { +func (e *endpoint) LastError() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -237,10 +241,10 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } - for _, mem := range e.multicastMemberships { + for mem := range e.multicastMemberships { e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } - e.multicastMemberships = nil + e.multicastMemberships = make(map[multicastMembership]struct{}) // Close the receive list and drain it. e.rcvMu.Lock() @@ -268,7 +272,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {} // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if err := e.takeLastError(); err != nil { + if err := e.LastError(); err != nil { return buffer.View{}, tcpip.ControlMessages{}, err } @@ -411,7 +415,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - if err := e.takeLastError(); err != nil { + if err := e.LastError(); err != nil { return 0, nil, err } @@ -483,10 +487,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - if to.Addr == header.IPv4Broadcast && !e.broadcast { - return 0, nil, tcpip.ErrBroadcastDisabled - } - dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -503,6 +503,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c resolve = route.Resolve } + if !e.broadcast && route.IsOutboundBroadcast() { + return 0, nil, tcpip.ErrBroadcastDisabled + } + if route.IsResolutionRequired() { if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { @@ -612,6 +616,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -676,9 +687,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case tcpip.MulticastInterfaceOption: + case *tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() @@ -714,7 +725,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastNICID = nic e.multicastAddr = addr - case tcpip.AddMembershipOption: + case *tcpip.AddMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { return tcpip.ErrInvalidOptionValue } @@ -745,19 +756,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - for _, mem := range e.multicastMemberships { - if mem == memToInsert { - return tcpip.ErrPortInUse - } + if _, ok := e.multicastMemberships[memToInsert]; ok { + return tcpip.ErrPortInUse } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } - e.multicastMemberships = append(e.multicastMemberships, memToInsert) + e.multicastMemberships[memToInsert] = struct{}{} - case tcpip.RemoveMembershipOption: + case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { return tcpip.ErrInvalidOptionValue } @@ -779,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() - for i, mem := range e.multicastMemberships { - if mem == memToRemove { - memToRemoveIndex = i - break - } - } - if memToRemoveIndex == -1 { + if _, ok := e.multicastMemberships[memToRemove]; !ok { return tcpip.ErrBadLocalAddress } @@ -798,17 +800,24 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return err } - e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + delete(e.multicastMemberships, memToRemove) - case tcpip.BindToDeviceOption: - id := tcpip.NICID(v) + case *tcpip.BindToDeviceOption: + id := tcpip.NICID(*v) if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -906,6 +915,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.MulticastTTLOption: e.mu.Lock() v := int(e.multicastTTL) @@ -946,10 +959,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case tcpip.ErrorOption: - return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -963,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() + case *tcpip.LingerOption: + e.mu.RLock() + *o = e.linger + e.mu.RUnlock() + default: return tcpip.ErrUnknownProtocolOption } @@ -972,13 +988,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { - // Allocate a buffer for the UDP header. - hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: data, + }) + pkt.Owner = owner - // Initialize the header. - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ SrcPort: localPort, DstPort: remotePort, @@ -1005,12 +1025,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Protocol: ProtocolNumber, TTL: ttl, TOS: tos, - }, &stack.PacketBuffer{ - Header: hdr, - Data: data, - TransportHeader: buffer.View(udp), - Owner: owner, - }); err != nil { + }, pkt); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -1208,13 +1223,13 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { return id, e.bindToDevice, err } @@ -1354,11 +1369,27 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { return result } +// verifyChecksum verifies the checksum unless RX checksum offload is enabled. +// On IPv4, UDP checksum is optional, and a zero value means the transmitter +// omitted the checksum generation (RFC768). +// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). +func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + return hdr.CalculateChecksum(xsum) == 0xffff + } + return true +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() @@ -1366,28 +1397,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - // Verify checksum unless RX checksum offload is enabled. - // On IPv4, UDP checksum is optional, and a zero value means - // the transmitter omitted the checksum generation (RFC768). - // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - if hdr.CalculateChecksum(xsum) != 0xffff { - // Checksum Error. - e.stack.Stats().UDP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - return - } + if !verifyChecksum(r, hdr, pkt) { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() @@ -1420,15 +1440,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Save any useful information from the network header to the packet. switch r.NetProto { case header.IPv4ProtocolNumber: - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.RemoteAddress - packet.packetInfo.NIC = r.NICID() + packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: - packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS() } - packet.timestamp = e.stack.NowNanoseconds() + // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast + // address. packetInfo.LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.LocalAddress + packet.packetInfo.NIC = r.NICID() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 851e6b635..858c99a45 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - for _, m := range e.multicastMemberships { + for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 0e7464e3a..da5b1deb2 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -12,18 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package udp contains the implementation of the UDP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing udp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package udp contains the implementation of the UDP transport protocol. package udp import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/waiter" @@ -49,6 +45,7 @@ const ( ) type protocol struct { + stack *stack.Stack } // Number returns the udp protocol number. @@ -57,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid udp packet size. @@ -79,131 +76,30 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } -// HandleUnknownDestinationPacket handles packets targeted at this protocol but -// that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - hdr := header.UDP(pkt.TransportHeader) +// HandleUnknownDestinationPacket handles packets that are targeted at this +// protocol but don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - // TODO(b/129426613): only send an ICMP message if UDP checksum is valid. - - // Only send ICMP error if the address is not a multicast/broadcast - // v4/v6 address or the source is not the unspecified address. - // - // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { - return true + return stack.UnknownDestinationPacketMalformed } - // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination - // Unreachable messages with code: - // - // 2 (Protocol Unreachable), when the designated transport protocol - // is not supported; or - // - // 3 (Port Unreachable), when the designated transport protocol - // (e.g., UDP) is unable to demultiplex the datagram but has no - // protocol mechanism to inform the sender. - switch len(id.LocalAddress) { - case header.IPv4AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() - return true - } - // As per RFC 1812 Section 4.3.2.3 - // - // ICMP datagram SHOULD contain as much of the original - // datagram as possible without the length of the ICMP - // datagram exceeding 576 bytes - // - // NOTE: The above RFC referenced is different from the original - // recommendation in RFC 1122 where it mentioned that at least 8 - // bytes of the payload must be included. Today linux and other - // systems implement the] RFC1812 definition and not the original - // RFC 1122 requirement. - mtu := int(r.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - - // The buffers used by pkt may be used elsewhere in the system. - // For example, a raw or packet socket may use what UDP - // considers an unreachable destination. Thus we deep copy pkt - // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader...) - newHeader = append(newHeader, pkt.TransportHeader...) - payload := newHeader.ToVectorisedView() - payload.AppendView(pkt.Data.ToView()) - payload.CapLength(payloadLen) - - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4DstUnreachable) - pkt.SetCode(header.ICMPv4PortUnreachable) - pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, - }) - - case header.IPv6AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() - return true - } - - // As per RFC 4443 section 2.4 - // - // (c) Every ICMPv6 error message (type < 128) MUST include - // as much of the IPv6 offending (invoking) packet (the - // packet that caused the error) as possible without making - // the error message packet exceed the minimum IPv6 MTU - // [IPv6]. - mtu := int(r.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize - available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) - payload.Append(pkt.Data) - payload.CapLength(payloadLen) - - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) - pkt.SetType(header.ICMPv6DstUnreachable) - pkt.SetCode(header.ICMPv6PortUnreachable) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, - }) + if !verifyChecksum(r, hdr, pkt) { + r.Stack().Stats().UDP.ChecksumErrors.Increment() + return stack.UnknownDestinationPacketMalformed } - return true + + return stack.UnknownDestinationPacketUnhandled } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -215,17 +111,10 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Packet is too small - return false - } - pkt.TransportHeader = h - pkt.Data.TrimFront(header.UDPMinimumSize) - return true + return parse.UDP(pkt) } // NewProtocol returns a UDP transport protocol. -func NewProtocol() stack.TransportProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index db59eb5a0..cddedb686 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -276,8 +294,8 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) } @@ -370,8 +388,12 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { + c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() h := flow.header4Tuple(outgoing) checkers = append( @@ -385,21 +407,38 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw } // injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. -func (c *testContext) injectPacket(flow testFlow, payload []byte) { +// and injects it into the link endpoint. If badChecksum is true, the packet has +// a bad checksum in the UDP header. +func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) { c.t.Helper() h := flow.header4Tuple(incoming) if flow.isV4() { buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + if badChecksum { + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + } + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } else { buf := c.buildV6Packet(payload, &h) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + if badChecksum { + // Invalidate the UDP header checksum field (Unlike IPv4, zero is + // a valid checksum value for IPv6 so no need to avoid it). + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + } + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } } @@ -493,8 +532,8 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -504,7 +543,7 @@ func TestBindToDeviceOption(t *testing.T) { opts := stack.NICOptions{Name: "my_device"} if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) } // nicIDPtr is used instead of taking the address of NICID literals, which is @@ -528,16 +567,15 @@ func TestBindToDeviceOption(t *testing.T) { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) + if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } bindToDevice := tcpip.BindToDeviceOption(88888) if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %v, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) + t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) + } else if bindToDevice != testAction.getBindToDevice { + t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) } @@ -551,7 +589,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe c.t.Helper() payload := newPayload() - c.injectPacket(flow, payload) + c.injectPacket(flow, payload, false) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -593,12 +631,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("bad payload: got %x, want %x", v, payload) + c.t.Fatalf("got payload = %x, want = %x", v, payload) } // Run any checkers against the ControlMessages. @@ -659,7 +697,7 @@ func TestBindReservedPort(t *testing.T) { } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) } } @@ -672,7 +710,7 @@ func TestBindReservedPort(t *testing.T) { // We can't bind ipv4-any on the port reserved by the connected endpoint // above, since the endpoint is dual-stack. if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { @@ -769,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: tt.handleLocal, }) defer c.cleanup() @@ -786,16 +824,16 @@ func TestV4ReadSelfSource(t *testing.T) { h.srcAddr = h.dstAddr buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) } if _, _, err := c.ep.Read(nil); err != tt.wantErr { - t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) + t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr) } }) } @@ -836,8 +874,8 @@ func TestReadOnBoundToMulticast(t *testing.T) { // Join multicast group. ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) } // Check that we receive multicast packets but not unicast or broadcast @@ -872,6 +910,24 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -1237,6 +1293,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { } } +func TestReadIPPacketInfo(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedLocalAddr tcpip.Address + expectedDestAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedLocalAddr: stackAddr, + expectedDestAddr: stackAddr, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastAddr, + expectedDestAddr: multicastAddr, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: broadcastAddr, + expectedDestAddr: broadcastAddr, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedLocalAddr: stackV6Addr, + expectedDestAddr: stackV6Addr, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastV6Addr, + expectedDestAddr: multicastV6Addr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { + t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) + } + + testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: 1, + LocalAddr: test.expectedLocalAddr, + DestinationAddr: test.expectedDestAddr, + })) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1275,6 +1430,30 @@ func TestNoChecksum(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1292,19 +1471,19 @@ func TestTTL(t *testing.T) { if flow.isMulticast() { wantTTL = multicastTTL } else { - var p stack.NetworkProtocol + var p stack.NetworkProtocolFactory + var n tcpip.NetworkProtocolNumber if flow.isV4() { - p = ipv4.NewProtocol() + p = ipv4.NewProtocol + n = ipv4.ProtocolNumber } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - if err != nil { - t.Fatal(err) + p = ipv6.NewProtocol + n = ipv6.ProtocolNumber } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{p}, + }) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) wantTTL = ep.DefaultTTL() ep.Close() } @@ -1328,21 +1507,6 @@ func TestSetTTL(t *testing.T) { c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - if err != nil { - t.Fatal(err) - } - ep.Close() - testWrite(c, flow, checker.TTL(wantTTL)) }) } @@ -1365,7 +1529,7 @@ func TestSetTOS(t *testing.T) { } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { @@ -1526,19 +1690,17 @@ func TestMulticastInterfaceOption(t *testing.T) { } } - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %s", err) + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) } // Verify multicast interface addr and NIC were set correctly. // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %s", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) + c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) + } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { + c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) } }) } @@ -1562,21 +1724,33 @@ func TestV4UnknownDestination(t *testing.T) { // so that the final generated IPv4 packet is larger than // header.IPv4MinimumProcessableDatagramSize. largePayload bool + // badChecksum if true, will set an invalid checksum in the + // header. + badChecksum bool }{ - {unicastV4, true, false}, - {unicastV4, true, true}, - {multicastV4, false, false}, - {multicastV4, false, true}, - {broadcast, false, false}, - {broadcast, false, true}, - } + {unicastV4, true, false, false}, + {unicastV4, true, true, false}, + {unicastV4, false, false, true}, + {unicastV4, false, true, true}, + {multicastV4, false, false, false}, + {multicastV4, false, true, false}, + {broadcast, false, false, false}, + {broadcast, false, true, false}, + } + checksumErrors := uint64(0) for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { payload := newPayload() if tc.largePayload { payload = newMinPayload(576) } - c.injectPacket(tc.flow, payload) + c.injectPacket(tc.flow, payload, tc.badChecksum) + if tc.badChecksum { + checksumErrors++ + if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { + t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + } if !tc.icmpRequired { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -1595,9 +1769,8 @@ func TestV4UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1607,16 +1780,26 @@ func TestV4UnknownDestination(t *testing.T) { checker.ICMPv4Type(header.ICMPv4DstUnreachable), checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + // We need to compare the included data part of the UDP packet that is in + // the ICMP packet with the matching original data. icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) + incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize wantLen := len(payload) if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + // To work out the data size we need to simulate what the sender would + // have done. The wanted size is the total available minus the sum of + // the headers in the UDP AND ICMP packets, given that we know the test + // had only a minimal IP header but the ICMP sender will have allowed + // for a maximally sized packet header. + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + } - // In case of large payloads the IP packet may be truncated. Update + // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + // Add back the two headers within the payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) if got, want := len(origDgram.Payload()), wantLen; got != want { @@ -1642,19 +1825,31 @@ func TestV6UnknownDestination(t *testing.T) { // largePayload if true will result in a payload large enough to // create an IPv6 packet > header.IPv6MinimumMTU bytes. largePayload bool + // badChecksum if true, will set an invalid checksum in the + // header. + badChecksum bool }{ - {unicastV6, true, false}, - {unicastV6, true, true}, - {multicastV6, false, false}, - {multicastV6, false, true}, - } + {unicastV6, true, false, false}, + {unicastV6, true, true, false}, + {unicastV6, false, false, true}, + {unicastV6, false, true, true}, + {multicastV6, false, false, false}, + {multicastV6, false, true, false}, + } + checksumErrors := uint64(0) for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { payload := newPayload() if tc.largePayload { payload = newMinPayload(1280) } - c.injectPacket(tc.flow, payload) + c.injectPacket(tc.flow, payload, tc.badChecksum) + if tc.badChecksum { + checksumErrors++ + if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { + t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + } if !tc.icmpRequired { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -1673,9 +1868,8 @@ func TestV6UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1721,12 +1915,14 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Invalidate the packet length field in the UDP header by adding one. + + // Invalidate the UDP header length field. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { @@ -1779,74 +1975,38 @@ func TestShortHeader(t *testing.T) { copy(buf[header.IPv6MinimumSize:], udpHdr) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) } } -// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// TestBadChecksumErrors verifies if a checksum error is detected, // global and endpoint stats are incremented. -func TestIncrementChecksumErrorsV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Invalidate the checksum field in the UDP header by adding one. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestIncrementChecksumErrorsV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() +func TestBadChecksumErrors(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Invalidate the checksum field in the UDP header by adding one. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } } } @@ -1865,11 +2025,12 @@ func TestPayloadModifiedV4(t *testing.T) { payload := newPayload() h := unicastV4.header4Tuple(incoming) buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -1895,11 +2056,12 @@ func TestPayloadModifiedV6(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -1928,9 +2090,9 @@ func TestChecksumZeroV4(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv4MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 0 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -1959,9 +2121,9 @@ func TestChecksumZeroV6(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2059,3 +2221,193 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const nicID1 = 1 + + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + requiresBroadcastOpt bool + }{ + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + requiresBroadcastOpt: true, + }, + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + requiresBroadcastOpt: false, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + // TODO(gvisor.dev/issue/3938): Once we support marking a route as + // broadcast, this test should require the broadcast option to be set. + requiresBroadcastOpt: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) + } + defer ep.Close() + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + to := tcpip.FullAddress{ + Addr: test.remoteAddr, + Port: 80, + } + opts := tcpip.WriteOptions{To: &to} + expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + if !test.requiresBroadcastOpt { + expectedErrWithoutBcastOpt = nil + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + }) + } +} diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go index 8fed29ff5..70945f234 100644 --- a/pkg/test/criutil/criutil.go +++ b/pkg/test/criutil/criutil.go @@ -22,6 +22,9 @@ import ( "fmt" "os" "os/exec" + "path" + "regexp" + "strconv" "strings" "time" @@ -33,28 +36,44 @@ import ( type Crictl struct { logger testutil.Logger endpoint string + runpArgs []string cleanup []func() } -// resolvePath attempts to find binary paths. It may set the path to invalid, +// ResolvePath attempts to find binary paths. It may set the path to invalid, // which will cause the execution to fail with a sensible error. -func resolvePath(executable string) string { +func ResolvePath(executable string) string { + runtime, err := dockerutil.RuntimePath() + if err == nil { + // Check first the directory of the runtime itself. + if dir := path.Dir(runtime); dir != "" && dir != "." { + guess := path.Join(dir, executable) + if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 { + return guess + } + } + } + + // Try to find via the path. guess, err := exec.LookPath(executable) - if err != nil { - guess = fmt.Sprintf("/usr/local/bin/%s", executable) + if err == nil { + return guess } - return guess + + // Return a default path. + return fmt.Sprintf("/usr/local/bin/%s", executable) } // NewCrictl returns a Crictl configured with a timeout and an endpoint over // which it will talk to containerd. -func NewCrictl(logger testutil.Logger, endpoint string) *Crictl { +func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl { // Attempt to find the executable, but don't bother propagating the // error at this point. The first command executed will return with a // binary not found error. return &Crictl{ logger: logger, endpoint: endpoint, + runpArgs: runpArgs, } } @@ -67,8 +86,8 @@ func (cc *Crictl) CleanUp() { } // RunPod creates a sandbox. It corresponds to `crictl runp`. -func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { - podID, err := cc.run("runp", sbSpecFile) +func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) { + podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile) if err != nil { return "", fmt.Errorf("runp failed: %v", err) } @@ -79,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { // Create creates a container within a sandbox. It corresponds to `crictl // create`. func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) { - podID, err := cc.run("create", podID, contSpecFile, sbSpecFile) + // In version 1.16.0, crictl annoying starting attempting to pull the + // container, even if it was already available locally. We therefore + // need to parse the version and add an appropriate --no-pull argument + // since the image has already been loaded locally. + out, err := cc.run("-v") + if err != nil { + return "", err + } + r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9+])") + vs := r.FindStringSubmatch(out) + if len(vs) != 4 { + return "", fmt.Errorf("crictl -v had unexpected output: %s", out) + } + major, err := strconv.ParseUint(vs[1], 10, 64) if err != nil { + return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) + } + minor, err := strconv.ParseUint(vs[2], 10, 64) + if err != nil { + return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) + } + + args := []string{"create"} + if (major == 1 && minor >= 16) || major > 1 { + args = append(args, "--no-pull") + } + args = append(args, podID) + args = append(args, contSpecFile) + args = append(args, sbSpecFile) + + podID, err = cc.run(args...) + if err != nil { + time.Sleep(10 * time.Minute) // XXX return "", fmt.Errorf("create failed: %v", err) } + // Strip the trailing newline from crictl output. return strings.TrimSpace(podID), nil } @@ -179,7 +230,7 @@ func (cc *Crictl) Import(image string) error { // be pushing a lot of bytes in order to import the image. The connect // timeout stays the same and is inherited from the Crictl instance. cmd := testutil.Command(cc.logger, - resolvePath("ctr"), + ResolvePath("ctr"), fmt.Sprintf("--connect-timeout=%s", 30*time.Second), fmt.Sprintf("--address=%s", cc.endpoint), "-n", "k8s.io", "images", "import", "-") @@ -260,7 +311,7 @@ func (cc *Crictl) StopContainer(contID string) error { // StartPodAndContainer starts a sandbox and container in that sandbox. It // returns the pod ID and container ID. -func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) { +func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) { if err := cc.Import(image); err != nil { return "", "", err } @@ -277,7 +328,7 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, } cc.cleanup = append(cc.cleanup, cleanup) - podID, err := cc.RunPod(sbSpecFile) + podID, err := cc.RunPod(runtime, sbSpecFile) if err != nil { return "", "", err } @@ -307,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error { // run runs crictl with the given args. func (cc *Crictl) run(args ...string) (string, error) { defaultArgs := []string{ - resolvePath("crictl"), + ResolvePath("crictl"), "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), } diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 7c8758e35..a5e84658a 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -1,14 +1,42 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "dockerutil", testonly = 1, - srcs = ["dockerutil.go"], + srcs = [ + "container.go", + "dockerutil.go", + "exec.go", + "network.go", + "profile.go", + ], visibility = ["//:sandbox"], deps = [ "//pkg/test/testutil", - "@com_github_kr_pty//:go_default_library", + "@com_github_docker_docker//api/types:go_default_library", + "@com_github_docker_docker//api/types/container:go_default_library", + "@com_github_docker_docker//api/types/mount:go_default_library", + "@com_github_docker_docker//api/types/network:go_default_library", + "@com_github_docker_docker//client:go_default_library", + "@com_github_docker_docker//pkg/stdcopy:go_default_library", + "@com_github_docker_go_connections//nat:go_default_library", + ], +) + +go_test( + name = "profile_test", + size = "large", + srcs = [ + "profile_test.go", + ], + library = ":dockerutil", + tags = [ + # Requires docker and runsc to be configured before test runs. + # Also requires the test to be run as root. + "manual", + "local", ], + visibility = ["//:sandbox"], ) diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md new file mode 100644 index 000000000..870292096 --- /dev/null +++ b/pkg/test/dockerutil/README.md @@ -0,0 +1,86 @@ +# dockerutil + +This package is for creating and controlling docker containers for testing +runsc, gVisor's docker/kubernetes binary. A simple test may look like: + +``` + func TestSuperCool(t *testing.T) { + ctx := context.Background() + c := dockerutil.MakeContainer(ctx, t) + got, err := c.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine" + }, "echo", "super cool") + if err != nil { + t.Fatalf("err was not nil: %v", err) + } + want := "super cool" + if !strings.Contains(got, want){ + t.Fatalf("want: %s, got: %s", want, got) + } + } +``` + +For further examples, see many of our end to end tests elsewhere in the repo, +such as those in //test/e2e or benchmarks at //test/benchmarks. + +dockerutil uses the "official" docker golang api, which is +[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil +is a thin wrapper around this API, allowing desired new use cases to be easily +implemented. + +## Profiling + +dockerutil is capable of generating profiles. Currently, the only option is to +use pprof profiles generated by `runsc debug`. The profiler will generate Block, +CPU, Heap, Goroutine, and Mutex profiles. To generate profiles: + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or + `--vfs2`. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles run: + +``` +make sudo TARGETS=//path/to:target \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` + +Container name in most tests and benchmarks in gVisor is usually the test name +and some random characters like so: +`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2` + +Profiling requires root as runsc debug inspects running containers in /var/run +among other things. + +### Writing for Profiling + +The below shows an example of using profiles with dockerutil. + +``` +func TestSuperCool(t *testing.T){ + ctx := context.Background() + // profiled and using runtime from dockerutil.runtime flag + profiled := MakeContainer() + + // not profiled and using runtime runc + native := MakeNativeContainer() + + err := profiled.Spawn(ctx, RunOpts{ + Image: "some/image", + }, "sleep", "100000") + // profiling has begun here + ... + expensive setup that I don't want to profile. + ... + profiled.RestartProfiles() + // profiled activity +} +``` + +In the above example, `profiled` would be profiled and `native` would not. The +call to `RestartProfiles()` restarts the clock on profiling. This is useful if +the main activity being tested is done with `docker exec` or `container.Spawn()` +followed by one or more `container.Exec()` calls. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go new file mode 100644 index 000000000..64d17f661 --- /dev/null +++ b/pkg/test/dockerutil/container.go @@ -0,0 +1,543 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net" + "os" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/docker/go-connections/nat" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Container represents a Docker Container allowing +// user to configure and control as one would with the 'docker' +// client. Container is backed by the offical golang docker API. +// See: https://pkg.go.dev/github.com/docker/docker. +type Container struct { + Name string + runtime string + + logger testutil.Logger + client *client.Client + id string + mounts []mount.Mount + links []string + copyErr error + cleanups []func() + + // Profiles are profiles added to this container. They contain methods + // that are run after Creation, Start, and Cleanup of this Container, along + // a handle to restart the profile. Generally, tests/benchmarks using + // profiles need to run as root. + profiles []Profile +} + +// RunOpts are options for running a container. +type RunOpts struct { + // Image is the image relative to images/. This will be mangled + // appropriately, to ensure that only first-party images are used. + Image string + + // Memory is the memory limit in bytes. + Memory int + + // Cpus in which to allow execution. ("0", "1", "0-2"). + CpusetCpus string + + // Ports are the ports to be allocated. + Ports []int + + // WorkDir sets the working directory. + WorkDir string + + // ReadOnly sets the read-only flag. + ReadOnly bool + + // Env are additional environment variables. + Env []string + + // User is the user to use. + User string + + // Privileged enables privileged mode. + Privileged bool + + // CapAdd are the extra set of capabilities to add. + CapAdd []string + + // CapDrop are the extra set of capabilities to drop. + CapDrop []string + + // Mounts is the list of directories/files to be mounted inside the container. + Mounts []mount.Mount + + // Links is the list of containers to be connected to the container. + Links []string +} + +// MakeContainer sets up the struct for a Docker container. +// +// Names of containers will be unique. +// Containers will check flags for profiling requests. +func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + c := MakeNativeContainer(ctx, logger) + c.runtime = *runtime + if p := MakePprofFromFlags(c); p != nil { + c.AddProfile(p) + } + return c +} + +// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native +// containers aren't profiled. +func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { + // Slashes are not allowed in container names. + name := testutil.RandomID(logger.Name()) + name = strings.ReplaceAll(name, "/", "-") + client, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + return nil + } + client.NegotiateAPIVersion(ctx) + return &Container{ + logger: logger, + Name: name, + runtime: "", + client: client, + } +} + +// AddProfile adds a profile to this container. +func (c *Container) AddProfile(p Profile) { + c.profiles = append(c.profiles, p) +} + +// RestartProfiles calls Restart on all profiles for this container. +func (c *Container) RestartProfiles() error { + for _, profile := range c.profiles { + if err := profile.Restart(c); err != nil { + return err + } + } + return nil +} + +// Spawn is analogous to 'docker run -d'. +func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { + return err + } + return c.Start(ctx) +} + +// SpawnProcess is analogous to 'docker run -it'. It returns a process +// which represents the root process. +func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) { + config, hostconf, netconf := c.ConfigsFrom(r, args...) + config.Tty = true + config.OpenStdin = true + + if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil { + return Process{}, err + } + + // Open a connection to the container for parsing logs and for TTY. + stream, err := c.client.ContainerAttach(ctx, c.id, + types.ContainerAttachOptions{ + Stream: true, + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + return Process{}, fmt.Errorf("connect failed container id %s: %v", c.id, err) + } + + c.cleanups = append(c.cleanups, func() { stream.Close() }) + + if err := c.Start(ctx); err != nil { + return Process{}, err + } + + return Process{container: c, conn: stream}, nil +} + +// Run is analogous to 'docker run'. +func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { + return "", err + } + + if err := c.Start(ctx); err != nil { + return "", err + } + + if err := c.Wait(ctx); err != nil { + return "", err + } + + return c.Logs(ctx) +} + +// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom' +// and Start. +func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) { + return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{} +} + +// MakeLink formats a link to add to a RunOpts. +func (c *Container) MakeLink(target string) string { + return fmt.Sprintf("%s:%s", c.Name, target) +} + +// CreateFrom creates a container from the given configs. +func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + return c.create(ctx, conf, hostconf, netconf) +} + +// Create is analogous to 'docker create'. +func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { + return c.create(ctx, c.config(r, args), c.hostConfig(r), nil) +} + +func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) + if err != nil { + return err + } + c.id = cont.ID + for _, profile := range c.profiles { + if err := profile.OnCreate(c); err != nil { + return fmt.Errorf("OnCreate method failed with: %v", err) + } + } + return nil +} + +func (c *Container) config(r RunOpts, args []string) *container.Config { + ports := nat.PortSet{} + for _, p := range r.Ports { + port := nat.Port(fmt.Sprintf("%d", p)) + ports[port] = struct{}{} + } + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + + return &container.Config{ + Image: testutil.ImageByName(r.Image), + Cmd: args, + ExposedPorts: ports, + Env: env, + WorkingDir: r.WorkDir, + User: r.User, + } +} + +func (c *Container) hostConfig(r RunOpts) *container.HostConfig { + c.mounts = append(c.mounts, r.Mounts...) + + return &container.HostConfig{ + Runtime: c.runtime, + Mounts: c.mounts, + PublishAllPorts: true, + Links: r.Links, + CapAdd: r.CapAdd, + CapDrop: r.CapDrop, + Privileged: r.Privileged, + ReadonlyRootfs: r.ReadOnly, + Resources: container.Resources{ + Memory: int64(r.Memory), // In bytes. + CpusetCpus: r.CpusetCpus, + }, + } +} + +// Start is analogous to 'docker start'. +func (c *Container) Start(ctx context.Context) error { + if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { + return fmt.Errorf("ContainerStart failed: %v", err) + } + for _, profile := range c.profiles { + if err := profile.OnStart(c); err != nil { + return fmt.Errorf("OnStart method failed: %v", err) + } + } + return nil +} + +// Stop is analogous to 'docker stop'. +func (c *Container) Stop(ctx context.Context) error { + return c.client.ContainerStop(ctx, c.id, nil) +} + +// Pause is analogous to'docker pause'. +func (c *Container) Pause(ctx context.Context) error { + return c.client.ContainerPause(ctx, c.id) +} + +// Unpause is analogous to 'docker unpause'. +func (c *Container) Unpause(ctx context.Context) error { + return c.client.ContainerUnpause(ctx, c.id) +} + +// Checkpoint is analogous to 'docker checkpoint'. +func (c *Container) Checkpoint(ctx context.Context, name string) error { + return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true}) +} + +// Restore is analogous to 'docker start --checkname [name]'. +func (c *Container) Restore(ctx context.Context, name string) error { + return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name}) +} + +// Logs is analogous 'docker logs'. +func (c *Container) Logs(ctx context.Context) (string, error) { + var out bytes.Buffer + err := c.logs(ctx, &out, &out) + return out.String(), err +} + +func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error { + opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true} + writer, err := c.client.ContainerLogs(ctx, c.id, opts) + if err != nil { + return err + } + defer writer.Close() + _, err = stdcopy.StdCopy(stdout, stderr, writer) + + return err +} + +// ID returns the container id. +func (c *Container) ID() string { + return c.id +} + +// SandboxPid returns the container's pid. +func (c *Container) SandboxPid(ctx context.Context) (int, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return -1, err + } + return resp.ContainerJSONBase.State.Pid, nil +} + +// FindIP returns the IP address of the container. +func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return nil, err + } + + var ip net.IP + if ipv6 { + ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.GlobalIPv6Address) + } else { + ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress) + } + if ip == nil { + return net.IP{}, fmt.Errorf("invalid IP: %q", ip) + } + return ip, nil +} + +// FindPort returns the host port that is mapped to 'sandboxPort'. +func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) { + desc, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return -1, fmt.Errorf("error retrieving port: %v", err) + } + + format := fmt.Sprintf("%d/tcp", sandboxPort) + ports, ok := desc.NetworkSettings.Ports[nat.Port(format)] + if !ok { + return -1, fmt.Errorf("error retrieving port: %v", err) + + } + + port, err := strconv.Atoi(ports[0].HostPort) + if err != nil { + return -1, fmt.Errorf("error parsing port %q: %v", port, err) + } + return port, nil +} + +// CopyFiles copies in and mounts the given files. They are always ReadOnly. +func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) { + dir, err := ioutil.TempDir("", c.Name) + if err != nil { + c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) + return + } + c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) }) + if err := os.Chmod(dir, 0755); err != nil { + c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) + return + } + for _, name := range sources { + src := name + if !filepath.IsAbs(src) { + src, err = testutil.FindFile(name) + if err != nil { + c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %w", name, err) + return + } + } + dst := path.Join(dir, path.Base(name)) + if err := testutil.Copy(src, dst); err != nil { + c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) + return + } + c.logger.Logf("copy: %s -> %s", src, dst) + } + opts.Mounts = append(opts.Mounts, mount.Mount{ + Type: mount.TypeBind, + Source: dir, + Target: target, + ReadOnly: false, + }) +} + +// Status inspects the container returns its status. +func (c *Container) Status(ctx context.Context) (types.ContainerState, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return types.ContainerState{}, err + } + return *resp.State, err +} + +// Wait waits for the container to exit. +func (c *Container) Wait(ctx context.Context) error { + statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) + select { + case err := <-errChan: + return err + case <-statusChan: + return nil + } +} + +// WaitTimeout waits for the container to exit with a timeout. +func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) + select { + case <-ctx.Done(): + if ctx.Err() == context.DeadlineExceeded { + return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds()) + } + return nil + case err := <-errChan: + return err + case <-statusChan: + return nil + } +} + +// WaitForOutput searches container logs for pattern and returns or timesout. +func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) { + matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout) + if err != nil { + return "", err + } + if len(matches) == 0 { + return "", fmt.Errorf("didn't find pattern %s logs", pattern) + } + return matches[0], nil +} + +// WaitForOutputSubmatch searches container logs for the given +// pattern or times out. It returns any regexp submatches as well. +func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + re := regexp.MustCompile(pattern) + for { + logs, err := c.Logs(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get logs: %v logs: %s", err, logs) + } + if matches := re.FindStringSubmatch(logs); matches != nil { + return matches, nil + } + time.Sleep(50 * time.Millisecond) + } +} + +// Kill kills the container. +func (c *Container) Kill(ctx context.Context) error { + return c.client.ContainerKill(ctx, c.id, "") +} + +// Remove is analogous to 'docker rm'. +func (c *Container) Remove(ctx context.Context) error { + // Remove the image. + remove := types.ContainerRemoveOptions{ + RemoveVolumes: c.mounts != nil, + RemoveLinks: c.links != nil, + Force: true, + } + return c.client.ContainerRemove(ctx, c.Name, remove) +} + +// CleanUp kills and deletes the container (best effort). +func (c *Container) CleanUp(ctx context.Context) { + // Execute profile cleanups before the container goes down. + for _, profile := range c.profiles { + profile.OnCleanUp(c) + } + + // Forget profiles. + c.profiles = nil + + // Execute all cleanups. We execute cleanups here to close any + // open connections to the container before closing. Open connections + // can cause Kill and Remove to hang. + for _, c := range c.cleanups { + c() + } + c.cleanups = nil + + // Kill the container. + if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { + // Just log; can't do anything here. + c.logger.Logf("error killing container %q: %v", c.Name, err) + } + // Remove the image. + if err := c.Remove(ctx); err != nil { + c.logger.Logf("error removing container %q: %v", c.Name, err) + } + // Forget all mounts. + c.mounts = nil +} diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 819dd0a59..7027df1a5 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -22,17 +22,11 @@ import ( "io" "io/ioutil" "log" - "net" - "os" "os/exec" - "path" "regexp" "strconv" - "strings" - "syscall" "time" - "github.com/kr/pty" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -49,6 +43,25 @@ var ( // config is the default Docker daemon configuration path. config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths") + + // The following flags are for the "pprof" profiler tool. + + // pprofBaseDir allows the user to change the directory to which profiles are + // written. By default, profiles will appear under: + // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. + pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") + + // duration is the max duration `runsc debug` will run and capture profiles. + // If the container's clean up method is called prior to duration, the + // profiling process will be killed. + duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds") + + // The below flags enable each type of profile. Multiple profiles can be + // enabled for each run. + pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") + pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") + pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") + pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug") ) // EnsureSupportedDockerVersion checks if correct docker is installed. @@ -74,44 +87,74 @@ func EnsureSupportedDockerVersion() { // RuntimePath returns the binary path for the current runtime. func RuntimePath() (string, error) { + rs, err := runtimeMap() + if err != nil { + return "", err + } + + p, ok := rs["path"].(string) + if !ok { + // The runtime does not declare a path. + return "", fmt.Errorf("runtime does not declare a path: %v", rs) + } + return p, nil +} + +// UsingVFS2 returns true if the 'runtime' has the vfs2 flag set. +// TODO(gvisor.dev/issue/1624): Remove. +func UsingVFS2() (bool, error) { + rMap, err := runtimeMap() + if err != nil { + return false, err + } + + list, ok := rMap["runtimeArgs"].([]interface{}) + if !ok { + return false, fmt.Errorf("unexpected format: %v", rMap) + } + + for _, element := range list { + if element == "--vfs2" { + return true, nil + } + } + return false, nil +} + +func runtimeMap() (map[string]interface{}, error) { // Read the configuration data; the file must exist. configBytes, err := ioutil.ReadFile(*config) if err != nil { - return "", err + return nil, err } // Unmarshal the configuration. c := make(map[string]interface{}) if err := json.Unmarshal(configBytes, &c); err != nil { - return "", err + return nil, err } // Decode the expected configuration. r, ok := c["runtimes"] if !ok { - return "", fmt.Errorf("no runtimes declared: %v", c) + return nil, fmt.Errorf("no runtimes declared: %v", c) } rs, ok := r.(map[string]interface{}) if !ok { // The runtimes are not a map. - return "", fmt.Errorf("unexpected format: %v", c) + return nil, fmt.Errorf("unexpected format: %v", rs) } r, ok = rs[*runtime] if !ok { // The expected runtime is not declared. - return "", fmt.Errorf("runtime %q not found: %v", *runtime, c) + return nil, fmt.Errorf("runtime %q not found: %v", *runtime, rs) } rs, ok = r.(map[string]interface{}) if !ok { // The runtime is not a map. - return "", fmt.Errorf("unexpected format: %v", c) + return nil, fmt.Errorf("unexpected format: %v", r) } - p, ok := rs["path"].(string) - if !ok { - // The runtime does not declare a path. - return "", fmt.Errorf("unexpected format: %v", c) - } - return p, nil + return rs, nil } // Save exports a container image to the given Writer. @@ -127,595 +170,7 @@ func Save(logger testutil.Logger, image string, w io.Writer) error { return cmd.Run() } -// MountMode describes if the mount should be ro or rw. -type MountMode int - -const ( - // ReadOnly is what the name says. - ReadOnly MountMode = iota - // ReadWrite is what the name says. - ReadWrite -) - -// String returns the mount mode argument for this MountMode. -func (m MountMode) String() string { - switch m { - case ReadOnly: - return "ro" - case ReadWrite: - return "rw" - } - panic(fmt.Sprintf("invalid mode: %d", m)) -} - -// DockerNetwork contains the name of a docker network. -type DockerNetwork struct { - logger testutil.Logger - Name string - Subnet *net.IPNet - containers []*Docker -} - -// NewDockerNetwork sets up the struct for a Docker network. Names of networks -// will be unique. -func NewDockerNetwork(logger testutil.Logger) *DockerNetwork { - return &DockerNetwork{ - logger: logger, - Name: testutil.RandomID(logger.Name()), - } -} - -// Create calls 'docker network create'. -func (n *DockerNetwork) Create(args ...string) error { - a := []string{"docker", "network", "create"} - if n.Subnet != nil { - a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet)) - } - a = append(a, args...) - a = append(a, n.Name) - return testutil.Command(n.logger, a...).Run() -} - -// Connect calls 'docker network connect' with the arguments provided. -func (n *DockerNetwork) Connect(container *Docker, args ...string) error { - a := []string{"docker", "network", "connect"} - a = append(a, args...) - a = append(a, n.Name, container.Name) - if err := testutil.Command(n.logger, a...).Run(); err != nil { - return err - } - n.containers = append(n.containers, container) - return nil -} - -// Cleanup cleans up the docker network and all the containers attached to it. -func (n *DockerNetwork) Cleanup() error { - for _, c := range n.containers { - // Don't propagate the error, it might be that the container - // was already cleaned up. - if err := c.Kill(); err != nil { - n.logger.Logf("unable to kill container during cleanup: %s", err) - } - } - - if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil { - return err - } - return nil -} - -// Docker contains the name and the runtime of a docker container. -type Docker struct { - logger testutil.Logger - Runtime string - Name string - copyErr error - cleanups []func() -} - -// MakeDocker sets up the struct for a Docker container. -// -// Names of containers will be unique. -func MakeDocker(logger testutil.Logger) *Docker { - // Slashes are not allowed in container names. - name := testutil.RandomID(logger.Name()) - name = strings.ReplaceAll(name, "/", "-") - - return &Docker{ - logger: logger, - Name: name, - Runtime: *runtime, - } -} - -// CopyFiles copies in and mounts the given files. They are always ReadOnly. -func (d *Docker) CopyFiles(opts *RunOpts, targetDir string, sources ...string) { - dir, err := ioutil.TempDir("", d.Name) - if err != nil { - d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) - return - } - d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) }) - if err := os.Chmod(dir, 0755); err != nil { - d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) - return - } - for _, name := range sources { - src, err := testutil.FindFile(name) - if err != nil { - d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) - return - } - dst := path.Join(dir, path.Base(name)) - if err := testutil.Copy(src, dst); err != nil { - d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) - return - } - d.logger.Logf("copy: %s -> %s", src, dst) - } - opts.Mounts = append(opts.Mounts, Mount{ - Source: dir, - Target: targetDir, - Mode: ReadOnly, - }) -} - -// Mount describes a mount point inside the container. -type Mount struct { - // Source is the path outside the container. - Source string - - // Target is the path inside the container. - Target string - - // Mode tells whether the mount inside the container should be readonly. - Mode MountMode -} - -// Link informs dockers that a given container needs to be made accessible from -// the container being configured. -type Link struct { - // Source is the container to connect to. - Source *Docker - - // Target is the alias for the container. - Target string -} - -// RunOpts are options for running a container. -type RunOpts struct { - // Image is the image relative to images/. This will be mangled - // appropriately, to ensure that only first-party images are used. - Image string - - // Memory is the memory limit in kB. - Memory int - - // Ports are the ports to be allocated. - Ports []int - - // WorkDir sets the working directory. - WorkDir string - - // ReadOnly sets the read-only flag. - ReadOnly bool - - // Env are additional environment variables. - Env []string - - // User is the user to use. - User string - - // Privileged enables privileged mode. - Privileged bool - - // CapAdd are the extra set of capabilities to add. - CapAdd []string - - // CapDrop are the extra set of capabilities to drop. - CapDrop []string - - // Pty indicates that a pty will be allocated. If this is non-nil, then - // this will run after start-up with the *exec.Command and Pty file - // passed in to the function. - Pty func(*exec.Cmd, *os.File) - - // Foreground indicates that the container should be run in the - // foreground. If this is true, then the output will be available as a - // return value from the Run function. - Foreground bool - - // Mounts is the list of directories/files to be mounted inside the container. - Mounts []Mount - - // Links is the list of containers to be connected to the container. - Links []Link - - // Extra are extra arguments that may be passed. - Extra []string -} - -// args returns common arguments. -// -// Note that this does not define the complete behavior. -func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) { - isExec := command == "exec" - isRun := command == "run" - - if isRun || isExec { - rv = append(rv, "-i") - } - if r.Pty != nil { - rv = append(rv, "-t") - } - if r.User != "" { - rv = append(rv, fmt.Sprintf("--user=%s", r.User)) - } - if r.Privileged { - rv = append(rv, "--privileged") - } - for _, c := range r.CapAdd { - rv = append(rv, fmt.Sprintf("--cap-add=%s", c)) - } - for _, c := range r.CapDrop { - rv = append(rv, fmt.Sprintf("--cap-drop=%s", c)) - } - for _, e := range r.Env { - rv = append(rv, fmt.Sprintf("--env=%s", e)) - } - if r.WorkDir != "" { - rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir)) - } - if !isExec { - if r.Memory != 0 { - rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory)) - } - for _, p := range r.Ports { - rv = append(rv, fmt.Sprintf("--publish=%d", p)) - } - if r.ReadOnly { - rv = append(rv, fmt.Sprintf("--read-only")) - } - if len(p) > 0 { - rv = append(rv, "--entrypoint=") - } - } - - // Always attach the test environment & Extra. - rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name)) - rv = append(rv, r.Extra...) - - // Attach necessary bits. - if isExec { - rv = append(rv, d.Name) - } else { - for _, m := range r.Mounts { - rv = append(rv, fmt.Sprintf("-v=%s:%s:%v", m.Source, m.Target, m.Mode)) - } - for _, l := range r.Links { - rv = append(rv, fmt.Sprintf("--link=%s:%s", l.Source.Name, l.Target)) - } - - if len(d.Runtime) > 0 { - rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) - } - rv = append(rv, fmt.Sprintf("--name=%s", d.Name)) - rv = append(rv, testutil.ImageByName(r.Image)) - } - - // Attach other arguments. - rv = append(rv, p...) - return rv -} - -// run runs a complete command. -func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) { - if d.copyErr != nil { - return "", d.copyErr - } - basicArgs := []string{"docker"} - if command == "spawn" { - command = "run" - basicArgs = append(basicArgs, command) - basicArgs = append(basicArgs, "-d") - } else { - basicArgs = append(basicArgs, command) - } - customArgs := d.argsFor(&r, command, p) - cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...) - if r.Pty != nil { - // If allocating a terminal, then we just ignore the output - // from the command. - ptmx, err := pty.Start(cmd.Cmd) - if err != nil { - return "", err - } - defer cmd.Wait() // Best effort. - r.Pty(cmd.Cmd, ptmx) - } else { - // Can't support PTY or streaming. - out, err := cmd.CombinedOutput() - return string(out), err - } - return "", nil -} - -// Create calls 'docker create' with the arguments provided. -func (d *Docker) Create(r RunOpts, args ...string) error { - out, err := d.run(r, "create", args...) - if strings.Contains(out, "Unable to find image") { - return fmt.Errorf("unable to find image, did you remember to `make load-%s`: %w", r.Image, err) - } - return err -} - -// Start calls 'docker start'. -func (d *Docker) Start() error { - return testutil.Command(d.logger, "docker", "start", d.Name).Run() -} - -// Stop calls 'docker stop'. -func (d *Docker) Stop() error { - return testutil.Command(d.logger, "docker", "stop", d.Name).Run() -} - -// Run calls 'docker run' with the arguments provided. -func (d *Docker) Run(r RunOpts, args ...string) (string, error) { - return d.run(r, "run", args...) -} - -// Spawn starts the container and detaches. -func (d *Docker) Spawn(r RunOpts, args ...string) error { - _, err := d.run(r, "spawn", args...) - return err -} - -// Logs calls 'docker logs'. -func (d *Docker) Logs() (string, error) { - // Don't capture the output; since it will swamp the logs. - out, err := exec.Command("docker", "logs", d.Name).CombinedOutput() - return string(out), err -} - -// Exec calls 'docker exec' with the arguments provided. -func (d *Docker) Exec(r RunOpts, args ...string) (string, error) { - return d.run(r, "exec", args...) -} - -// Pause calls 'docker pause'. -func (d *Docker) Pause() error { - return testutil.Command(d.logger, "docker", "pause", d.Name).Run() -} - -// Unpause calls 'docker pause'. -func (d *Docker) Unpause() error { - return testutil.Command(d.logger, "docker", "unpause", d.Name).Run() -} - -// Checkpoint calls 'docker checkpoint'. -func (d *Docker) Checkpoint(name string) error { - return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run() -} - -// Restore calls 'docker start --checkname [name]'. -func (d *Docker) Restore(name string) error { - return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run() -} - -// Kill calls 'docker kill'. -func (d *Docker) Kill() error { - // Skip logging this command, it will likely be an error. - out, err := exec.Command("docker", "kill", d.Name).CombinedOutput() - if err != nil && !strings.Contains(string(out), "is not running") { - return err - } - return nil -} - -// Remove calls 'docker rm'. -func (d *Docker) Remove() error { - return testutil.Command(d.logger, "docker", "rm", d.Name).Run() -} - -// CleanUp kills and deletes the container (best effort). -func (d *Docker) CleanUp() { - // Kill the container. - if err := d.Kill(); err != nil { - // Just log; can't do anything here. - d.logger.Logf("error killing container %q: %v", d.Name, err) - } - // Remove the image. - if err := d.Remove(); err != nil { - d.logger.Logf("error removing container %q: %v", d.Name, err) - } - // Execute all cleanups. - for _, c := range d.cleanups { - c() - } - d.cleanups = nil -} - -// FindPort returns the host port that is mapped to 'sandboxPort'. This calls -// docker to allocate a free port in the host and prevent conflicts. -func (d *Docker) FindPort(sandboxPort int) (int, error) { - format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort) - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving port: %v", err) - } - port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing port %q: %v", out, err) - } - return port, nil -} - -// FindIP returns the IP address of the container. -func (d *Docker) FindIP() (net.IP, error) { - const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return net.IP{}, fmt.Errorf("error retrieving IP: %v", err) - } - ip := net.ParseIP(strings.TrimSpace(string(out))) - if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", string(out)) - } - return ip, nil -} - -// A NetworkInterface is container's network interface information. -type NetworkInterface struct { - IPv4 net.IP - MAC net.HardwareAddr -} - -// ListNetworks returns the network interfaces of the container, keyed by -// Docker network name. -func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) { - const format = `{{json .NetworkSettings.Networks}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err) - } - - networks := map[string]map[string]string{} - if err := json.Unmarshal(out, &networks); err != nil { - return nil, fmt.Errorf("error decoding network interfaces: %w", err) - } - - interfaces := map[string]NetworkInterface{} - for name, iface := range networks { - var netface NetworkInterface - - rawIP := strings.TrimSpace(iface["IPAddress"]) - if rawIP != "" { - ip := net.ParseIP(rawIP) - if ip == nil { - return nil, fmt.Errorf("invalid IP: %q", rawIP) - } - // Docker's IPAddress field is IPv4. The IPv6 address - // is stored in the GlobalIPv6Address field. - netface.IPv4 = ip - } - - rawMAC := strings.TrimSpace(iface["MacAddress"]) - if rawMAC != "" { - mac, err := net.ParseMAC(rawMAC) - if err != nil { - return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err) - } - netface.MAC = mac - } - - interfaces[name] = netface - } - - return interfaces, nil -} - -// SandboxPid returns the PID to the sandbox process. -func (d *Docker) SandboxPid() (int, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving pid: %v", err) - } - pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing pid %q: %v", out, err) - } - return pid, nil -} - -// ID returns the container ID. -func (d *Docker) ID() (string, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput() - if err != nil { - return "", fmt.Errorf("error retrieving ID: %v", err) - } - return strings.TrimSpace(string(out)), nil -} - -// Wait waits for container to exit, up to the given timeout. Returns error if -// wait fails or timeout is hit. Returns the application return code otherwise. -// Note that the application may have failed even if err == nil, always check -// the exit code. -func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) { - timeoutChan := time.After(timeout) - waitChan := make(chan (syscall.WaitStatus)) - errChan := make(chan (error)) - - go func() { - out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput() - if err != nil { - errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err) - } - exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err) - } - waitChan <- syscall.WaitStatus(uint32(exit)) - }() - - select { - case ws := <-waitChan: - return ws, nil - case err := <-errChan: - return syscall.WaitStatus(1), err - case <-timeoutChan: - return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name) - } -} - -// WaitForOutput calls 'docker logs' to retrieve containers output and searches -// for the given pattern. -func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) { - matches, err := d.WaitForOutputSubmatch(pattern, timeout) - if err != nil { - return "", err - } - if len(matches) == 0 { - return "", nil - } - return matches[0], nil -} - -// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and -// searches for the given pattern. It returns any regexp submatches as well. -func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) { - re := regexp.MustCompile(pattern) - var ( - lastOut string - stopped bool - ) - for exp := time.Now().Add(timeout); time.Now().Before(exp); { - out, err := d.Logs() - if err != nil { - return nil, err - } - if out != lastOut { - if lastOut == "" { - d.logger.Logf("output (start): %s", out) - } else if strings.HasPrefix(out, lastOut) { - d.logger.Logf("output (contn): %s", out[len(lastOut):]) - } else { - d.logger.Logf("output (trunc): %s", out) - } - lastOut = out // Save for future. - if matches := re.FindStringSubmatch(lastOut); matches != nil { - return matches, nil // Success! - } - } else if stopped { - // The sandbox stopped and we looked at the - // logs at least once since determining that. - return nil, fmt.Errorf("no longer running: %v", err) - } else if pid, err := d.SandboxPid(); pid == 0 || err != nil { - // The sandbox may have stopped, but it's - // possible that it has emitted the terminal - // line between the last call to Logs and here. - stopped = true - } - time.Sleep(100 * time.Millisecond) - } - return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut) +// Runtime returns the value of the flag runtime. +func Runtime() string { + return *runtime } diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go new file mode 100644 index 000000000..4c739c9e9 --- /dev/null +++ b/pkg/test/dockerutil/exec.go @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/pkg/stdcopy" +) + +// ExecOpts holds arguments for Exec calls. +type ExecOpts struct { + // Env are additional environment variables. + Env []string + + // Privileged enables privileged mode. + Privileged bool + + // User is the user to use. + User string + + // Enables Tty and stdin for the created process. + UseTTY bool + + // WorkDir is the working directory of the process. + WorkDir string +} + +// Exec creates a process inside the container. +func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) { + p, err := c.doExec(ctx, opts, args) + if err != nil { + return "", err + } + + if exitStatus, err := p.WaitExitStatus(ctx); err != nil { + return "", err + } else if exitStatus != 0 { + out, _ := p.Logs() + return out, fmt.Errorf("process terminated with status: %d", exitStatus) + } + + return p.Logs() +} + +// ExecProcess creates a process inside the container and returns a process struct +// for the caller to use. +func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) { + return c.doExec(ctx, opts, args) +} + +func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) { + config := c.execConfig(r, args) + resp, err := c.client.ContainerExecCreate(ctx, c.id, config) + if err != nil { + return Process{}, fmt.Errorf("exec create failed with err: %v", err) + } + + hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{}) + if err != nil { + return Process{}, fmt.Errorf("exec attach failed with err: %v", err) + } + + if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil { + hijack.Close() + return Process{}, fmt.Errorf("exec start failed with err: %v", err) + } + + return Process{ + container: c, + execid: resp.ID, + conn: hijack, + }, nil +} + +func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig { + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + return types.ExecConfig{ + AttachStdin: r.UseTTY, + AttachStderr: true, + AttachStdout: true, + Cmd: cmd, + Privileged: r.Privileged, + WorkingDir: r.WorkDir, + Env: env, + Tty: r.UseTTY, + User: r.User, + } + +} + +// Process represents a containerized process. +type Process struct { + container *Container + execid string + conn types.HijackedResponse +} + +// Write writes buf to the process's stdin. +func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) { + p.conn.Conn.SetDeadline(time.Now().Add(timeout)) + return p.conn.Conn.Write(buf) +} + +// Read returns process's stdout and stderr. +func (p *Process) Read() (string, string, error) { + var stdout, stderr bytes.Buffer + if err := p.read(&stdout, &stderr); err != nil { + return "", "", err + } + return stdout.String(), stderr.String(), nil +} + +// Logs returns combined stdout/stderr from the process. +func (p *Process) Logs() (string, error) { + var out bytes.Buffer + if err := p.read(&out, &out); err != nil { + return "", err + } + return out.String(), nil +} + +func (p *Process) read(stdout, stderr *bytes.Buffer) error { + _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader) + return err +} + +// ExitCode returns the process's exit code. +func (p *Process) ExitCode(ctx context.Context) (int, error) { + _, exitCode, err := p.runningExitCode(ctx) + return exitCode, err +} + +// IsRunning checks if the process is running. +func (p *Process) IsRunning(ctx context.Context) (bool, error) { + running, _, err := p.runningExitCode(ctx) + return running, err +} + +// WaitExitStatus until process completes and returns exit status. +func (p *Process) WaitExitStatus(ctx context.Context) (int, error) { + waitChan := make(chan (int)) + errChan := make(chan (error)) + + go func() { + for { + running, exitcode, err := p.runningExitCode(ctx) + if err != nil { + errChan <- fmt.Errorf("error waiting process %s: container %v", p.execid, p.container.Name) + } + if !running { + waitChan <- exitcode + } + time.Sleep(time.Millisecond * 500) + } + }() + + select { + case ws := <-waitChan: + return ws, nil + case err := <-errChan: + return -1, err + } +} + +// runningExitCode collects if the process is running and the exit code. +// The exit code is only valid if the process has exited. +func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) { + // If execid is not empty, this is a execed process. + if p.execid != "" { + status, err := p.container.client.ContainerExecInspect(ctx, p.execid) + return status.Running, status.ExitCode, err + } + // else this is the root process. + status, err := p.container.Status(ctx) + return status.Running, status.ExitCode, err +} diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go new file mode 100644 index 000000000..047091e75 --- /dev/null +++ b/pkg/test/dockerutil/network.go @@ -0,0 +1,113 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "net" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Network is a docker network. +type Network struct { + client *client.Client + id string + logger testutil.Logger + Name string + containers []*Container + Subnet *net.IPNet +} + +// NewNetwork sets up the struct for a Docker network. Names of networks +// will be unique. +func NewNetwork(ctx context.Context, logger testutil.Logger) *Network { + client, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + logger.Logf("create client failed with: %v", err) + return nil + } + client.NegotiateAPIVersion(ctx) + + return &Network{ + logger: logger, + Name: testutil.RandomID(logger.Name()), + client: client, + } +} + +func (n *Network) networkCreate() types.NetworkCreate { + + var subnet string + if n.Subnet != nil { + subnet = n.Subnet.String() + } + + ipam := network.IPAM{ + Config: []network.IPAMConfig{{ + Subnet: subnet, + }}, + } + + return types.NetworkCreate{ + CheckDuplicate: true, + IPAM: &ipam, + } +} + +// Create is analogous to 'docker network create'. +func (n *Network) Create(ctx context.Context) error { + + opts := n.networkCreate() + resp, err := n.client.NetworkCreate(ctx, n.Name, opts) + if err != nil { + return err + } + n.id = resp.ID + return nil +} + +// Connect is analogous to 'docker network connect' with the arguments provided. +func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error { + settings := network.EndpointSettings{ + IPAMConfig: &network.EndpointIPAMConfig{ + IPv4Address: ipv4, + IPv6Address: ipv6, + }, + } + err := n.client.NetworkConnect(ctx, n.id, container.id, &settings) + if err == nil { + n.containers = append(n.containers, container) + } + return err +} + +// Inspect returns this network's info. +func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) { + return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true}) +} + +// Cleanup cleans up the docker network and all the containers attached to it. +func (n *Network) Cleanup(ctx context.Context) error { + for _, c := range n.containers { + c.CleanUp(ctx) + } + n.containers = nil + + return n.client.NetworkRemove(ctx, n.id) +} diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go new file mode 100644 index 000000000..55f9496cd --- /dev/null +++ b/pkg/test/dockerutil/profile.go @@ -0,0 +1,147 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "time" +) + +// Profile represents profile-like operations on a container, +// such as running perf or pprof. It is meant to be added to containers +// such that the container type calls the Profile during its lifecycle. +type Profile interface { + // OnCreate is called just after the container is created when the container + // has a valid ID (e.g. c.ID()). + OnCreate(c *Container) error + + // OnStart is called just after the container is started when the container + // has a valid Pid (e.g. c.SandboxPid()). + OnStart(c *Container) error + + // Restart restarts the Profile on request. + Restart(c *Container) error + + // OnCleanUp is called during the container's cleanup method. + // Cleanups should just log errors if they have them. + OnCleanUp(c *Container) error +} + +// Pprof is for running profiles with 'runsc debug'. Pprof workloads +// should be run as root and ONLY against runsc sandboxes. The runtime +// should have --profile set as an option in /etc/docker/daemon.json in +// order for profiling to work with Pprof. +type Pprof struct { + BasePath string // path to put profiles + BlockProfile bool + CPUProfile bool + HeapProfile bool + MutexProfile bool + Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. + shouldRun bool + cmd *exec.Cmd + stdout io.ReadCloser + stderr io.ReadCloser +} + +// MakePprofFromFlags makes a Pprof profile from flags. +func MakePprofFromFlags(c *Container) *Pprof { + if !(*pprofBlock || *pprofCPU || *pprofHeap || *pprofMutex) { + return nil + } + return &Pprof{ + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + BlockProfile: *pprofBlock, + CPUProfile: *pprofCPU, + HeapProfile: *pprofHeap, + MutexProfile: *pprofMutex, + Duration: *duration, + } +} + +// OnCreate implements Profile.OnCreate. +func (p *Pprof) OnCreate(c *Container) error { + return os.MkdirAll(p.BasePath, 0755) +} + +// OnStart implements Profile.OnStart. +func (p *Pprof) OnStart(c *Container) error { + path, err := RuntimePath() + if err != nil { + return fmt.Errorf("failed to get runtime path: %v", err) + } + + // The root directory of this container's runtime. + root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`. + args := []string{root, "debug"} + args = append(args, p.makeProfileArgs(c)...) + args = append(args, c.ID()) + + // Best effort wait until container is running. + for now := time.Now(); time.Since(now) < 5*time.Second; { + if status, err := c.Status(context.Background()); err != nil { + return fmt.Errorf("failed to get status with: %v", err) + + } else if status.Running { + break + } + time.Sleep(500 * time.Millisecond) + } + p.cmd = exec.Command(path, args...) + if err := p.cmd.Start(); err != nil { + return fmt.Errorf("process failed: %v", err) + } + return nil +} + +// Restart implements Profile.Restart. +func (p *Pprof) Restart(c *Container) error { + p.OnCleanUp(c) + return p.OnStart(c) +} + +// OnCleanUp implements Profile.OnCleanup +func (p *Pprof) OnCleanUp(c *Container) error { + defer func() { p.cmd = nil }() + if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() { + return p.cmd.Process.Kill() + } + return nil +} + +// makeProfileArgs turns Pprof fields into runsc debug flags. +func (p *Pprof) makeProfileArgs(c *Container) []string { + var ret []string + if p.BlockProfile { + ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof"))) + } + if p.CPUProfile { + ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) + } + if p.HeapProfile { + ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) + } + if p.MutexProfile { + ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof"))) + } + ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration)) + return ret +} diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go new file mode 100644 index 000000000..8c4ffe483 --- /dev/null +++ b/pkg/test/dockerutil/profile_test.go @@ -0,0 +1,116 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +type testCase struct { + name string + pprof Pprof + expectedFiles []string +} + +func TestPprof(t *testing.T) { + // Basepath and expected file names for each type of profile. + basePath := "/tmp/test/profile" + block := "block.pprof" + cpu := "cpu.pprof" + goprofle := "go.pprof" + heap := "heap.pprof" + mutex := "mutex.pprof" + + testCases := []testCase{ + { + name: "Cpu", + pprof: Pprof{ + BasePath: basePath, + CPUProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{cpu}, + }, + { + name: "All", + pprof: Pprof{ + BasePath: basePath, + BlockProfile: true, + CPUProfile: true, + HeapProfile: true, + MutexProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{block, cpu, goprofle, heap, mutex}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + c := MakeContainer(ctx, t) + // Set basepath to include the container name so there are no conflicts. + tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name) + c.AddProfile(&tc.pprof) + + func() { + defer c.CleanUp(ctx) + // Start a container. + if err := c.Spawn(ctx, RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { + t.Fatalf("run failed with: %v", err) + } + + if status, err := c.Status(context.Background()); !status.Running { + t.Fatalf("container is not yet running: %+v err: %v", status, err) + } + + // End early if the expected files exist and have data. + for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) { + if err := checkFiles(tc); err == nil { + break + } + } + }() + + // Check all expected files exist and have data. + if err := checkFiles(tc); err != nil { + t.Fatalf(err.Error()) + } + }) + } +} + +func checkFiles(tc testCase) error { + for _, file := range tc.expectedFiles { + stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file)) + if err != nil { + return fmt.Errorf("stat failed with: %v", err) + } else if stat.Size() < 1 { + return fmt.Errorf("file not written to: %+v", stat) + } + } + return nil +} + +func TestMain(m *testing.M) { + EnsureSupportedDockerVersion() + os.Exit(m.Run()) +} diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD index 03b1b4677..c4b131896 100644 --- a/pkg/test/testutil/BUILD +++ b/pkg/test/testutil/BUILD @@ -12,9 +12,9 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/sync", - "//runsc/boot", + "//runsc/config", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index f21d6769a..49ab87c58 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -44,7 +44,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -133,25 +133,28 @@ func Command(logger Logger, args ...string) *Cmd { // TestConfig returns the default configuration to use in tests. Note that // 'RootDir' must be set by caller if required. -func TestConfig(t *testing.T) *boot.Config { +func TestConfig(t *testing.T) *config.Config { logDir := os.TempDir() if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { logDir = dir + "/" } - return &boot.Config{ - Debug: true, - DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"), - LogFormat: "text", - DebugLogFormat: "text", - LogPackets: true, - Network: boot.NetworkNone, - Strace: true, - Platform: "ptrace", - FileAccess: boot.FileAccessExclusive, - NumNetworkChannels: 1, - TestOnlyAllowRunAsCurrentUserWithoutChroot: true, - } + // Only register flags if config is being used. Otherwise anyone that uses + // testutil will get flags registered and they may conflict. + config.RegisterFlags() + + conf, err := config.NewFromFlags() + if err != nil { + panic(err) + } + // Change test defaults. + conf.Debug = true + conf.DebugLog = path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%") + conf.LogPackets = true + conf.Network = config.NetworkNone + conf.Strace = true + conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true + return conf } // NewSpecWithArgs creates a simple spec with the given args suitable for use @@ -203,7 +206,7 @@ func SetupRootDir() (string, func(), error) { // SetupContainer creates a bundle and root dir for the container, generates a // test config, and writes the spec to config.json in the bundle dir. -func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, cleanup func(), err error) { +func SetupContainer(spec *specs.Spec, conf *config.Config) (rootDir, bundleDir string, cleanup func(), err error) { rootDir, rootCleanup, err := SetupRootDir() if err != nil { return "", "", nil, err @@ -243,12 +246,15 @@ func writeSpec(dir string, spec *specs.Spec) error { return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755) } +// idRandomSrc is a pseudo random generator used to in RandomID. +var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano())) + // RandomID returns 20 random bytes following the given prefix. func RandomID(prefix string) string { // Read 20 random bytes. b := make([]byte, 20) // "[Read] always returns len(p) and a nil error." --godoc - if _, err := rand.Read(b); err != nil { + if _, err := idRandomSrc.Read(b); err != nil { panic("rand.Read failed: " + err.Error()) } if prefix != "" { @@ -264,7 +270,7 @@ func RandomID(prefix string) string { // same name, sometimes between test runs the socket does not get cleaned up // quickly enough, causing container creation to fail. func RandomContainerID() string { - return RandomID("test-container-") + return RandomID("test-container") } // Copy copies file from src to dst. @@ -316,18 +322,23 @@ func Copy(src, dst string) error { func Poll(cb func() error, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() + return PollContext(ctx, cb) +} + +// PollContext is like Poll, but takes a context instead of a timeout. +func PollContext(ctx context.Context, cb func() error) error { b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) return backoff.Retry(cb, b) } // WaitForHTTP tries GET requests on a port until the call succeeds or timeout. -func WaitForHTTP(port int, timeout time.Duration) error { +func WaitForHTTP(ip string, port int, timeout time.Duration) error { cb := func() error { c := &http.Client{ // Calculate timeout to be able to do minimum 5 attempts. Timeout: timeout / 5, } - url := fmt.Sprintf("http://localhost:%d/", port) + url := fmt.Sprintf("http://%s:%d/", ip, port) resp, err := c.Get(url) if err != nil { log.Printf("Waiting %s: %v", url, err) @@ -482,6 +493,21 @@ func IsStatic(filename string) (bool, error) { return true, nil } +// TouchShardStatusFile indicates to Bazel that the test runner supports +// sharding by creating or updating the last modified date of the file +// specified by TEST_SHARD_STATUS_FILE. +// +// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner. +func TouchShardStatusFile() error { + if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" { + cmd := exec.Command("touch", statusFile) + if b, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error()) + } + } + return nil +} + // TestIndicesForShard returns indices for this test shard based on the // TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. // diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go index d843f19cf..c976d7230 100644 --- a/pkg/unet/unet.go +++ b/pkg/unet/unet.go @@ -522,7 +522,7 @@ func (s *ServerSocket) Listen() error { // This is always blocking. // // Preconditions: -// * ServerSocket is listening (Listen called). +// * ServerSocket is listening (Listen called). func (s *ServerSocket) Accept() (*Socket, error) { fd, ok := s.socket.enterFD() if !ok { diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/usermem/addr_range_seq_unsafe.go index c09337c15..495896ded 100644 --- a/pkg/usermem/addr_range_seq_unsafe.go +++ b/pkg/usermem/addr_range_seq_unsafe.go @@ -81,8 +81,10 @@ func AddrRangeSeqFromSlice(slice []AddrRange) AddrRangeSeq { return addrRangeSeqFromSliceLimited(slice, limit) } -// Preconditions: The combined length of all AddrRanges in slice <= limit. -// limit >= 0. If len(slice) != 0, then limit > 0. +// Preconditions: +// * The combined length of all AddrRanges in slice <= limit. +// * limit >= 0. +// * If len(slice) != 0, then limit > 0. func addrRangeSeqFromSliceLimited(slice []AddrRange, limit int64) AddrRangeSeq { switch len(slice) { case 0: diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index cd6a0ea6b..9b1e7a085 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -21,7 +21,6 @@ import ( "io" "strconv" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safemem" @@ -54,8 +53,10 @@ type IO interface { // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a // non-nil error explaining why. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. toZero >= 0. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * toZero >= 0. ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at @@ -66,9 +67,11 @@ type IO interface { // // CopyOutFrom calls src.ReadToBlocks at most once. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. src.ReadToBlocks must not block - // on mm.MemoryManager.activeMu or any preceding locks in the lock order. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * src.ReadToBlocks must not block on mm.MemoryManager.activeMu or + // any preceding locks in the lock order. CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to @@ -78,10 +81,11 @@ type IO interface { // // CopyInTo calls dst.WriteFromBlocks at most once. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. dst.WriteFromBlocks must not - // block on mm.MemoryManager.activeMu or any preceding locks in the lock - // order. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * dst.WriteFromBlocks must not block on mm.MemoryManager.activeMu or + // any preceding locks in the lock order. CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst @@ -93,25 +97,28 @@ type IO interface { // SwapUint32 atomically sets the uint32 value at addr to new and // returns the previous value. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * addr must be aligned to a 4-byte boundary. SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) // CompareAndSwapUint32 atomically compares the uint32 value at addr to // old; if they are equal, the value in memory is replaced by new. In // either case, the previous value stored in memory is returned. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * addr must be aligned to a 4-byte boundary. CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) // LoadUint32 atomically loads the uint32 value at addr and returns it. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * addr must be aligned to a 4-byte boundary. LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) } @@ -176,51 +183,6 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) { return n, err } -// CopyObjectOut copies a fixed-size value or slice of fixed-size values from -// src to the memory mapped at addr in uio. It returns the number of bytes -// copied. -// -// CopyObjectOut must use reflection to encode src; performance-sensitive -// clients should do encoding manually and use uio.CopyOut directly. -// -// Preconditions: As for IO.CopyOut. -func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) { - w := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - // Allocate a byte slice the size of the object being marshaled. This - // adds an extra reflection call, but avoids needing to grow the slice - // during encoding, which can result in many heap-allocated slices. - b := make([]byte, 0, binary.Size(src)) - return w.Write(binary.Marshal(b, ByteOrder, src)) -} - -// CopyObjectIn copies a fixed-size value or slice of fixed-size values from -// the memory mapped at addr in uio to dst. It returns the number of bytes -// copied. -// -// CopyObjectIn must use reflection to decode dst; performance-sensitive -// clients should use uio.CopyIn directly and do decoding manually. -// -// Preconditions: As for IO.CopyIn. -func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) { - r := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - buf := make([]byte, binary.Size(dst)) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - binary.Unmarshal(buf, ByteOrder, dst) - return int(r.Addr - addr), nil -} - // CopyStringIn tuning parameters, defined outside that function for tests. const ( copyStringIncrement = 64 @@ -233,7 +195,8 @@ const ( // would exceed maxlen, CopyStringIn returns the string truncated to maxlen and // ENAMETOOLONG. // -// Preconditions: As for IO.CopyFromUser. maxlen >= 0. +// Preconditions: Same as IO.CopyFromUser, plus: +// * maxlen >= 0. func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) { initLen := maxlen if initLen > copyStringMaxInitBufLen { @@ -287,7 +250,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt // less. CopyOutVec returns the number of bytes copied; if this is less than // the maximum, it returns a non-nil error explaining why. // -// Preconditions: As for IO.CopyOut. +// Preconditions: Same as IO.CopyOut. func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) { var done int for !ars.IsEmpty() && done < len(src) { @@ -311,7 +274,7 @@ func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts // less. CopyInVec returns the number of bytes copied; if this is less than the // maximum, it returns a non-nil error explaining why. // -// Preconditions: As for IO.CopyIn. +// Preconditions: Same as IO.CopyIn. func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) { var done int for !ars.IsEmpty() && done < len(dst) { @@ -335,7 +298,7 @@ func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts I // ZeroOutVec returns the number of bytes written; if this is less than the // maximum, it returns a non-nil error explaining why. // -// Preconditions: As for IO.ZeroOut. +// Preconditions: Same as IO.ZeroOut. func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) { var done int64 for !ars.IsEmpty() && done < toZero { @@ -388,7 +351,7 @@ func isASCIIWhitespace(b byte) bool { // // - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0. // -// Preconditions: As for CopyInVec. +// Preconditions: Same as CopyInVec. func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) { if len(dsts) == 0 { return 0, nil @@ -481,28 +444,28 @@ func (s IOSequence) NumBytes() int64 { // DropFirst returns a copy of s with s.Addrs.DropFirst(n). // -// Preconditions: As for AddrRangeSeq.DropFirst. +// Preconditions: Same as AddrRangeSeq.DropFirst. func (s IOSequence) DropFirst(n int) IOSequence { return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts} } // DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n). // -// Preconditions: As for AddrRangeSeq.DropFirst64. +// Preconditions: Same as AddrRangeSeq.DropFirst64. func (s IOSequence) DropFirst64(n int64) IOSequence { return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts} } // TakeFirst returns a copy of s with s.Addrs.TakeFirst(n). // -// Preconditions: As for AddrRangeSeq.TakeFirst. +// Preconditions: Same as AddrRangeSeq.TakeFirst. func (s IOSequence) TakeFirst(n int) IOSequence { return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts} } // TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n). // -// Preconditions: As for AddrRangeSeq.TakeFirst64. +// Preconditions: Same as AddrRangeSeq.TakeFirst64. func (s IOSequence) TakeFirst64(n int64) IOSequence { return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts} } @@ -512,7 +475,7 @@ func (s IOSequence) TakeFirst64(n int64) IOSequence { // As with CopyOutVec, if s.NumBytes() < len(src), the copy will be truncated // to s.NumBytes(), and a nil error will be returned. // -// Preconditions: As for CopyOutVec. +// Preconditions: Same as CopyOutVec. func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) { return CopyOutVec(ctx, s.IO, s.Addrs, src, s.Opts) } @@ -522,7 +485,7 @@ func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) { // As with CopyInVec, if s.NumBytes() < len(dst), the copy will be truncated to // s.NumBytes(), and a nil error will be returned. // -// Preconditions: As for CopyInVec. +// Preconditions: Same as CopyInVec. func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) { return CopyInVec(ctx, s.IO, s.Addrs, dst, s.Opts) } @@ -532,21 +495,21 @@ func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) { // As with ZeroOutVec, if s.NumBytes() < toZero, the write will be truncated // to s.NumBytes(), and a nil error will be returned. // -// Preconditions: As for ZeroOutVec. +// Preconditions: Same as ZeroOutVec. func (s IOSequence) ZeroOut(ctx context.Context, toZero int64) (int64, error) { return ZeroOutVec(ctx, s.IO, s.Addrs, toZero, s.Opts) } // CopyOutFrom invokes s.CopyOutFrom over s.Addrs. // -// Preconditions: As for IO.CopyOutFrom. +// Preconditions: Same as IO.CopyOutFrom. func (s IOSequence) CopyOutFrom(ctx context.Context, src safemem.Reader) (int64, error) { return s.IO.CopyOutFrom(ctx, s.Addrs, src, s.Opts) } // CopyInTo invokes s.CopyInTo over s.Addrs. // -// Preconditions: As for IO.CopyInTo. +// Preconditions: Same as IO.CopyInTo. func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, error) { return s.IO.CopyInTo(ctx, s.Addrs, dst, s.Opts) } diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go index bf3c5df2b..da60b0cc7 100644 --- a/pkg/usermem/usermem_test.go +++ b/pkg/usermem/usermem_test.go @@ -16,7 +16,6 @@ package usermem import ( "bytes" - "encoding/binary" "fmt" "reflect" "strings" @@ -174,23 +173,6 @@ type testStruct struct { Uint64 uint64 } -func TestCopyObject(t *testing.T) { - wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} - wantN := binary.Size(wantObj) - b := &BytesIO{make([]byte, wantN)} - ctx := newContext() - if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { - t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - var gotObj testStruct - if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { - t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if gotObj != wantObj { - t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) - } -} - func TestCopyStringInShort(t *testing.T) { // Tests for string length <= copyStringIncrement. want := strings.Repeat("A", copyStringIncrement-2) diff --git a/runsc/BUILD b/runsc/BUILD index 757f6d44c..33d8554af 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -17,8 +17,8 @@ go_binary( "//pkg/log", "//pkg/refs", "//pkg/sentry/platform", - "//runsc/boot", "//runsc/cmd", + "//runsc/config", "//runsc/flag", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", @@ -53,66 +53,14 @@ go_binary( "//pkg/log", "//pkg/refs", "//pkg/sentry/platform", - "//runsc/boot", "//runsc/cmd", + "//runsc/config", "//runsc/flag", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", ], ) -pkg_tar( - name = "runsc-bin", - srcs = [":runsc"], - mode = "0755", - package_dir = "/usr/bin", - strip_prefix = "/runsc/linux_amd64_pure_stripped", -) - -pkg_tar( - name = "debian-data", - extension = "tar.gz", - deps = [ - ":runsc-bin", - ], -) - -genrule( - name = "deb-version", - # Note that runsc must appear in the srcs parameter and not the tools - # parameter, otherwise it will not be stamped. This is reasonable, as tools - # may be encoded differently in the build graph (cached more aggressively - # because they are assumes to be hermetic). - srcs = [":runsc"], - outs = ["version.txt"], - # Note that the little dance here is necessary because files in the $(SRCS) - # attribute are not executable by default, and we can't touch in place. - cmd = "cp $(location :runsc) $(@D)/runsc && \ - chmod a+x $(@D)/runsc && \ - $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ - rm -f $(@D)/runsc", - stamp = 1, -) - -pkg_deb( - name = "runsc-debian", - architecture = "amd64", - data = ":debian-data", - # Note that the description_file will be flatten (all newlines removed), - # and therefore it is kept to a simple one-line description. The expected - # format for debian packages is "short summary\nLonger explanation of - # tool." and this is impossible with the flattening. - description_file = "debian/description", - homepage = "https://gvisor.dev/", - maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", - package = "runsc", - postinst = "debian/postinst.sh", - version_file = ":version.txt", - visibility = [ - "//visibility:public", - ], -) - sh_test( name = "version_test", size = "small", diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index aad2a41de..2d9517f4a 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -8,7 +8,6 @@ go_library( "compat.go", "compat_amd64.go", "compat_arm64.go", - "config.go", "controller.go", "debug.go", "events.go", @@ -27,10 +26,13 @@ go_library( deps = [ "//pkg/abi", "//pkg/abi/linux", + "//pkg/bpf", + "//pkg/cleanup", "//pkg/context", "//pkg/control/server", "//pkg/cpuid", "//pkg/eventchannel", + "//pkg/fd", "//pkg/fspath", "//pkg/log", "//pkg/memutil", @@ -90,6 +92,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", @@ -104,9 +107,11 @@ go_library( "//runsc/boot/filter", "//runsc/boot/platforms", "//runsc/boot/pprof", + "//runsc/config", "//runsc/specutils", + "//runsc/specutils/seccomp", "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -122,6 +127,7 @@ go_test( library = ":boot", deps = [ "//pkg/control/server", + "//pkg/fd", "//pkg/fspath", "//pkg/log", "//pkg/p9", @@ -130,8 +136,9 @@ go_test( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/unet", + "//runsc/config", "//runsc/fsgofer", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 8125d5061..894651519 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -22,6 +22,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/control/server" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -33,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot/pprof" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -101,14 +103,13 @@ const ( // Profiling related commands (see pprof.go for more details). const ( - StartCPUProfile = "Profile.StartCPUProfile" - StopCPUProfile = "Profile.StopCPUProfile" - HeapProfile = "Profile.HeapProfile" - GoroutineProfile = "Profile.GoroutineProfile" - BlockProfile = "Profile.BlockProfile" - MutexProfile = "Profile.MutexProfile" - StartTrace = "Profile.StartTrace" - StopTrace = "Profile.StopTrace" + StartCPUProfile = "Profile.StartCPUProfile" + StopCPUProfile = "Profile.StopCPUProfile" + HeapProfile = "Profile.HeapProfile" + BlockProfile = "Profile.BlockProfile" + MutexProfile = "Profile.MutexProfile" + StartTrace = "Profile.StartTrace" + StopTrace = "Profile.StopTrace" ) // Logging related commands (see logging.go for more details). @@ -129,42 +130,52 @@ type controller struct { // manager holds the containerManager methods. manager *containerManager + + // pprop holds the profile instance if enabled. It may be nil. + pprof *control.Profile } // newController creates a new controller. The caller must call // controller.srv.StartServing() to start the controller. func newController(fd int, l *Loader) (*controller, error) { - srv, err := server.CreateFromFD(fd) + ctrl := &controller{} + var err error + ctrl.srv, err = server.CreateFromFD(fd) if err != nil { return nil, err } - manager := &containerManager{ + ctrl.manager = &containerManager{ startChan: make(chan struct{}), startResultChan: make(chan error), l: l, } - srv.Register(manager) + ctrl.srv.Register(ctrl.manager) if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok { net := &Network{ Stack: eps.Stack, } - srv.Register(net) + ctrl.srv.Register(net) } - srv.Register(&debug{}) - srv.Register(&control.Logging{}) - if l.conf.ProfileEnable { - srv.Register(&control.Profile{ - Kernel: l.k, - }) + ctrl.srv.Register(&debug{}) + ctrl.srv.Register(&control.Logging{}) + + if l.root.conf.ProfileEnable { + ctrl.pprof = &control.Profile{Kernel: l.k} + ctrl.srv.Register(ctrl.pprof) } - return &controller{ - srv: srv, - manager: manager, - }, nil + return ctrl, nil +} + +func (c *controller) stop() { + if c.pprof != nil { + // These are noop if there is nothing being profiled. + _ = c.pprof.StopCPUProfile(nil, nil) + _ = c.pprof.StopTrace(nil, nil) + } } // containerManager manages sandbox containers. @@ -211,7 +222,7 @@ type StartArgs struct { Spec *specs.Spec // Config is the runsc-specific configuration for the sandbox. - Conf *Config + Conf *config.Config // CID is the ID of the container to start. CID string @@ -247,13 +258,20 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { // All validation passed, logs the spec for debugging. specutils.LogSpec(args.Spec) - err := cm.l.startContainer(args.Spec, args.Conf, args.CID, args.FilePayload.Files) + fds, err := fd.NewFromFiles(args.FilePayload.Files) if err != nil { + return err + } + defer func() { + for _, fd := range fds { + _ = fd.Close() + } + }() + if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil { log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err) return err } log.Debugf("Container %q started", args.CID) - return nil } @@ -333,7 +351,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Pause the kernel while we build a new one. cm.l.k.Pause() - p, err := createPlatform(cm.l.conf, deviceFile) + p, err := createPlatform(cm.l.root.conf, deviceFile) if err != nil { return fmt.Errorf("creating platform: %v", err) } @@ -349,8 +367,8 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { cm.l.k = k // Set up the restore environment. - mntr := newContainerMounter(cm.l.spec, cm.l.goferFDs, cm.l.k, cm.l.mountHints) - renv, err := mntr.createRestoreEnvironment(cm.l.conf) + mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints) + renv, err := mntr.createRestoreEnvironment(cm.l.root.conf) if err != nil { return fmt.Errorf("creating RestoreEnvironment: %v", err) } @@ -368,7 +386,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { return fmt.Errorf("file cannot be empty") } - if cm.l.conf.ProfileEnable { + if cm.l.root.conf.ProfileEnable { // pprof.Initialize opens /proc/self/maps, so has to be called before // installing seccomp filters. pprof.Initialize() @@ -387,13 +405,13 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Since we have a new kernel we also must make a new watchdog. dogOpts := watchdog.DefaultOpts - dogOpts.TaskTimeoutAction = cm.l.conf.WatchdogAction + dogOpts.TaskTimeoutAction = cm.l.root.conf.WatchdogAction dog := watchdog.New(k, dogOpts) // Change the loader fields to reflect the changes made when restoring. cm.l.k = k cm.l.watchdog = dog - cm.l.rootProcArgs = kernel.CreateProcessArgs{} + cm.l.root.procArgs = kernel.CreateProcessArgs{} cm.l.restore = true // Reinitialize the sandbox ID and processes map. Note that it doesn't diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 60e33425f..6ac19668f 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -27,41 +27,30 @@ import ( // allowedSyscalls is the set of syscalls executed by the Sentry to the host OS. var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_CLOCK_GETTIME: {}, - syscall.SYS_CLONE: []seccomp.Rule{ - { - seccomp.AllowValue( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), - }, - }, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, + syscall.SYS_CLOSE: {}, + syscall.SYS_DUP: {}, syscall.SYS_DUP3: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.O_CLOEXEC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.O_CLOEXEC), }, }, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, }, syscall.SYS_EVENTFD2: []seccomp.Rule{ { - seccomp.AllowValue(0), - seccomp.AllowValue(0), + seccomp.EqualTo(0), + seccomp.EqualTo(0), }, }, syscall.SYS_EXIT: {}, @@ -70,16 +59,16 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_FCHMOD: {}, syscall.SYS_FCNTL: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_SETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_SETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFD), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFD), }, }, syscall.SYS_FSTAT: {}, @@ -87,52 +76,52 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_FTRUNCATE: {}, syscall.SYS_FUTEX: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, }, // Non-private variants are included for flipcall support. They are otherwise // unncessary, as the sentry will use only private futexes internally. { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE), - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE), + seccomp.MatchAny{}, }, }, syscall.SYS_GETPID: {}, unix.SYS_GETRANDOM: {}, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_DOMAIN), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_DOMAIN), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_TYPE), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_TYPE), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_ERROR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_ERROR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_SNDBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), }, }, syscall.SYS_GETTID: {}, @@ -141,34 +130,34 @@ var allowedSyscalls = seccomp.SyscallRules{ // setting/getting termios and winsize. syscall.SYS_IOCTL: []seccomp.Rule{ { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCGETS), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCGETS), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCSETS), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCSETS), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCSETSF), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCSETSF), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCSETSW), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCSETSW), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TIOCSWINSZ), - seccomp.AllowAny{}, /* winsize struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TIOCSWINSZ), + seccomp.MatchAny{}, /* winsize struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TIOCGWINSZ), - seccomp.AllowAny{}, /* winsize struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TIOCGWINSZ), + seccomp.MatchAny{}, /* winsize struct */ }, }, syscall.SYS_LSEEK: {}, @@ -182,46 +171,46 @@ var allowedSyscalls = seccomp.SyscallRules{ // TODO(b/148688965): Remove once this is gone from Go. syscall.SYS_MLOCK: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(4096), + seccomp.MatchAny{}, + seccomp.EqualTo(4096), }, }, syscall.SYS_MMAP: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_SHARED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_SHARED), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.PROT_WRITE | syscall.PROT_READ), - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.PROT_WRITE | syscall.PROT_READ), + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), }, }, syscall.SYS_MPROTECT: {}, @@ -237,32 +226,32 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_READ: {}, syscall.SYS_RECVMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), }, }, syscall.SYS_RECVMMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(fdbased.MaxMsgsPerRecv), - seccomp.AllowValue(syscall.MSG_DONTWAIT), - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(fdbased.MaxMsgsPerRecv), + seccomp.EqualTo(syscall.MSG_DONTWAIT), + seccomp.EqualTo(0), }, }, unix.SYS_SENDMMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT), - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT), + seccomp.EqualTo(0), }, }, syscall.SYS_RESTART_SYSCALL: {}, @@ -272,57 +261,50 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_SCHED_YIELD: {}, syscall.SYS_SENDMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), }, }, syscall.SYS_SETITIMER: {}, syscall.SYS_SHUTDOWN: []seccomp.Rule{ // Used by fs/host to shutdown host sockets. - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RD)}, - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_WR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RD)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_WR)}, // Used by unet to shutdown connections. - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, unix.SYS_STATX: {}, syscall.SYS_SYNC_FILE_RANGE: {}, syscall.SYS_TEE: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(1), /* len */ - seccomp.AllowValue(unix.SPLICE_F_NONBLOCK), /* flags */ + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(1), /* len */ + seccomp.EqualTo(unix.SPLICE_F_NONBLOCK), /* flags */ }, }, syscall.SYS_TGKILL: []seccomp.Rule{ { - seccomp.AllowValue(uint64(os.Getpid())), + seccomp.EqualTo(uint64(os.Getpid())), }, }, syscall.SYS_UTIMENSAT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(0), /* null pathname */ - seccomp.AllowAny{}, - seccomp.AllowValue(0), /* flags */ + seccomp.MatchAny{}, + seccomp.EqualTo(0), /* null pathname */ + seccomp.MatchAny{}, + seccomp.EqualTo(0), /* flags */ }, }, syscall.SYS_WRITE: {}, - // The only user in rawfile.NonBlockingWrite3 always passes iovcnt with - // values 2 or 3. Three iovec-s are passed, when the PACKET_VNET_HDR - // option is enabled for a packet socket. + // For rawfile.NonBlockingWriteIovec. syscall.SYS_WRITEV: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(2), - }, - { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(3), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.GreaterThan(0), }, }, } @@ -332,10 +314,10 @@ func hostInetFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ syscall.SYS_ACCEPT4: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), }, }, syscall.SYS_BIND: {}, @@ -344,84 +326,84 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_GETSOCKNAME: {}, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_TOS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_TOS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_RECVTOS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVTOS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_TCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_TCLASS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_RECVTCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVTCLASS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_V6ONLY), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_V6ONLY), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_ERROR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_ERROR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_KEEPALIVE), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_KEEPALIVE), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_SNDBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_RCVBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_RCVBUF), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_REUSEADDR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_REUSEADDR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_TYPE), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_TYPE), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_LINGER), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_LINGER), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_TCP), - seccomp.AllowValue(syscall.TCP_NODELAY), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(syscall.TCP_NODELAY), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_TCP), - seccomp.AllowValue(syscall.TCP_INFO), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(syscall.TCP_INFO), }, }, syscall.SYS_IOCTL: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.TIOCOUTQ), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.TIOCOUTQ), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.TIOCINQ), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.TIOCINQ), }, }, syscall.SYS_LISTEN: {}, @@ -432,103 +414,103 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_SENDTO: {}, syscall.SYS_SETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_V6ONLY), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_SNDBUF), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_RCVBUF), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_RCVBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_REUSEADDR), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_REUSEADDR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_TCP), - seccomp.AllowValue(syscall.TCP_NODELAY), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(syscall.TCP_NODELAY), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_TOS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_TOS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_RECVTOS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVTOS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_TCLASS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_TCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_RECVTCLASS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVTCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SHUT_RD), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SHUT_RD), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SHUT_WR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SHUT_WR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SHUT_RDWR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SHUT_RDWR), }, }, syscall.SYS_SOCKET: []seccomp.Rule{ { - seccomp.AllowValue(syscall.AF_INET), - seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET), + seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_INET), - seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET), + seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_INET6), - seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET6), + seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_INET6), - seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET6), + seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, }, syscall.SYS_WRITEV: {}, @@ -539,20 +521,20 @@ func controlServerFilters(fd int) seccomp.SyscallRules { return seccomp.SyscallRules{ syscall.SYS_ACCEPT: []seccomp.Rule{ { - seccomp.AllowValue(fd), + seccomp.EqualTo(fd), }, }, syscall.SYS_LISTEN: []seccomp.Rule{ { - seccomp.AllowValue(fd), - seccomp.AllowValue(16 /* unet.backlog */), + seccomp.EqualTo(fd), + seccomp.EqualTo(16 /* unet.backlog */), }, }, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_PEERCRED), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_PEERCRED), }, }, } diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go index 5335ff82c..cea5613b8 100644 --- a/runsc/boot/filter/config_amd64.go +++ b/runsc/boot/filter/config_amd64.go @@ -24,8 +24,41 @@ import ( ) func init() { - allowedSyscalls[syscall.SYS_ARCH_PRCTL] = append(allowedSyscalls[syscall.SYS_ARCH_PRCTL], - seccomp.Rule{seccomp.AllowValue(linux.ARCH_GET_FS)}, - seccomp.Rule{seccomp.AllowValue(linux.ARCH_SET_FS)}, - ) + allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ + // TODO(b/168828518): No longer used in Go 1.16+. + {seccomp.EqualTo(linux.ARCH_SET_FS)}, + } + + allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + // parent_tidptr and child_tidptr are always 0 because neither + // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used. + { + seccomp.EqualTo( + syscall.CLONE_VM | + syscall.CLONE_FS | + syscall.CLONE_FILES | + syscall.CLONE_SETTLS | + syscall.CLONE_SIGHAND | + syscall.CLONE_SYSVSEM | + syscall.CLONE_THREAD), + seccomp.MatchAny{}, // newsp + seccomp.EqualTo(0), // parent_tidptr + seccomp.EqualTo(0), // child_tidptr + seccomp.MatchAny{}, // tls + }, + { + // TODO(b/168828518): No longer used in Go 1.16+ (on amd64). + seccomp.EqualTo( + syscall.CLONE_VM | + syscall.CLONE_FS | + syscall.CLONE_FILES | + syscall.CLONE_SIGHAND | + syscall.CLONE_SYSVSEM | + syscall.CLONE_THREAD), + seccomp.MatchAny{}, // newsp + seccomp.EqualTo(0), // parent_tidptr + seccomp.EqualTo(0), // child_tidptr + seccomp.MatchAny{}, // tls + }, + } } diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go index 7fa9bbda3..37313f97f 100644 --- a/runsc/boot/filter/config_arm64.go +++ b/runsc/boot/filter/config_arm64.go @@ -16,6 +16,29 @@ package filter -// Reserve for future customization. +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + func init() { + allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + { + seccomp.EqualTo( + syscall.CLONE_VM | + syscall.CLONE_FS | + syscall.CLONE_FILES | + syscall.CLONE_SIGHAND | + syscall.CLONE_SYSVSEM | + syscall.CLONE_THREAD), + seccomp.MatchAny{}, // newsp + // These arguments are left uninitialized by the Go + // runtime, so they may be anything (and are unused by + // the host). + seccomp.MatchAny{}, // parent_tidptr + seccomp.MatchAny{}, // tls + seccomp.MatchAny{}, // child_tidptr + }, + } } diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go index 194952a7b..7b8669595 100644 --- a/runsc/boot/filter/config_profile.go +++ b/runsc/boot/filter/config_profile.go @@ -25,9 +25,9 @@ func profileFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ syscall.SYS_OPENAT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), }, }, } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index e83584b82..ddf288456 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -29,10 +29,12 @@ import ( _ "gvisor.dev/gvisor/pkg/sentry/fs/sys" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tty" + "gvisor.dev/gvisor/pkg/sentry/vfs" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/gofer" @@ -47,6 +49,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -65,7 +68,7 @@ const ( // tmpfs has some extra supported options that we must pass through. var tmpfsAllowedData = []string{"mode", "uid", "gid"} -func addOverlay(ctx context.Context, conf *Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) { +func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) { // Upper layer uses the same flags as lower, but it must be read-write. upperFlags := lowerFlags upperFlags.ReadOnly = false @@ -155,7 +158,7 @@ func compileMounts(spec *specs.Spec) []specs.Mount { } // p9MountData creates a slice of p9 mount data. -func p9MountData(fd int, fa FileAccessType, vfs2 bool) []string { +func p9MountData(fd int, fa config.FileAccessType, vfs2 bool) []string { opts := []string{ "trans=fd", "rfdno=" + strconv.Itoa(fd), @@ -166,7 +169,7 @@ func p9MountData(fd int, fa FileAccessType, vfs2 bool) []string { // enablement. opts = append(opts, "privateunixsocket=true") } - if fa == FileAccessShared { + if fa == config.FileAccessShared { opts = append(opts, "cache=remote_revalidating") } return opts @@ -251,7 +254,7 @@ func mustFindFilesystem(name string) fs.Filesystem { // addSubmountOverlay overlays the inode over a ramfs tree containing the given // paths. -func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string) (*fs.Inode, error) { +func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string, mf fs.MountSourceFlags) (*fs.Inode, error) { // Construct a ramfs tree of mount points. The contents never // change, so this can be fully caching. There's no real // filesystem backing this tree, so we set the filesystem to @@ -261,7 +264,7 @@ func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string if err != nil { return nil, fmt.Errorf("creating mount tree: %v", err) } - overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, fs.MountSourceFlags{}) + overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, mf) if err != nil { return nil, fmt.Errorf("adding mount overlay: %v", err) } @@ -280,7 +283,7 @@ func subtargets(root string, mnts []specs.Mount) []string { return targets } -func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { +func setupContainerFS(ctx context.Context, conf *config.Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { if conf.VFS2 { return setupContainerVFS2(ctx, conf, mntr, procArgs) } @@ -318,14 +321,14 @@ func adjustDirentCache(k *kernel.Kernel) error { } type fdDispenser struct { - fds []int + fds []*fd.FD } func (f *fdDispenser) remove() int { if f.empty() { panic("fdDispenser out of fds") } - rv := f.fds[0] + rv := f.fds[0].Release() f.fds = f.fds[1:] return rv } @@ -390,6 +393,10 @@ type mountHint struct { // root is the inode where the volume is mounted. For mounts with 'pod' share // the volume is mounted once and then bind mounted inside the containers. root *fs.Inode + + // vfsMount is the master mount for the volume. For mounts with 'pod' share + // the master volume is bind mounted inside the containers. + vfsMount *vfs.Mount } func (m *mountHint) setField(key, val string) error { @@ -447,27 +454,27 @@ func (m *mountHint) isSupported() bool { func (m *mountHint) checkCompatible(mount specs.Mount) error { // Remove options that don't affect to mount's behavior. masterOpts := filterUnsupportedOptions(m.mount) - slaveOpts := filterUnsupportedOptions(mount) + replicaOpts := filterUnsupportedOptions(mount) - if len(masterOpts) != len(slaveOpts) { - return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts) + if len(masterOpts) != len(replicaOpts) { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, replicaOpts) } sort.Strings(masterOpts) - sort.Strings(slaveOpts) + sort.Strings(replicaOpts) for i, opt := range masterOpts { - if opt != slaveOpts[i] { - return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts) + if opt != replicaOpts[i] { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, replicaOpts) } } return nil } -func (m *mountHint) fileAccessType() FileAccessType { +func (m *mountHint) fileAccessType() config.FileAccessType { if m.share == container { - return FileAccessExclusive + return config.FileAccessExclusive } - return FileAccessShared + return config.FileAccessShared } func filterUnsupportedOptions(mount specs.Mount) []string { @@ -558,7 +565,7 @@ type containerMounter struct { hints *podMountHints } -func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hints *podMountHints) *containerMounter { +func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints) *containerMounter { return &containerMounter{ root: spec.Root, mounts: compileMounts(spec), @@ -571,9 +578,9 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin // processHints processes annotations that container hints about how volumes // should be mounted (e.g. a volume shared between containers). It must be // called for the root container only. -func (c *containerMounter) processHints(conf *Config) error { +func (c *containerMounter) processHints(conf *config.Config, creds *auth.Credentials) error { if conf.VFS2 { - return nil + return c.processHintsVFS2(conf, creds) } ctx := c.k.SupervisorContext() for _, hint := range c.hints.mounts { @@ -595,7 +602,7 @@ func (c *containerMounter) processHints(conf *Config) error { // setupFS is used to set up the file system for all containers. This is the // main entry point method, with most of the other being internal only. It // returns the mount namespace that is created for the container. -func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) { +func (c *containerMounter) setupFS(conf *config.Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) { log.Infof("Configuring container's file system") // Create context with root credentials to mount the filesystem (the current @@ -621,7 +628,7 @@ func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessA return mns, nil } -func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Config) (*fs.MountNamespace, error) { +func (c *containerMounter) createMountNamespace(ctx context.Context, conf *config.Config) (*fs.MountNamespace, error) { rootInode, err := c.createRootMount(ctx, conf) if err != nil { return nil, fmt.Errorf("creating filesystem for container: %v", err) @@ -633,9 +640,9 @@ func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Confi return mns, nil } -func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error { +func (c *containerMounter) mountSubmounts(ctx context.Context, conf *config.Config, mns *fs.MountNamespace) error { root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, m := range c.mounts { log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options) @@ -669,7 +676,7 @@ func (c *containerMounter) checkDispenser() error { // mountSharedMaster mounts the master of a volume that is shared among // containers in a pod. It returns the root mount's inode. -func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *Config, hint *mountHint) (*fs.Inode, error) { +func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *config.Config, hint *mountHint) (*fs.Inode, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, hint.mount) @@ -709,7 +716,7 @@ func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *Config, } // createRootMount creates the root filesystem. -func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (*fs.Inode, error) { +func (c *containerMounter) createRootMount(ctx context.Context, conf *config.Config) (*fs.Inode, error) { // First construct the filesystem from the spec.Root. mf := fs.MountSourceFlags{ReadOnly: c.root.Readonly || conf.Overlay} @@ -734,7 +741,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (* // for submount paths. "/dev" "/sys" "/proc" and "/tmp" are always // mounted even if they are not in the spec. submounts := append(subtargets("/", c.mounts), "/dev", "/sys", "/proc", "/tmp") - rootInode, err = addSubmountOverlay(ctx, rootInode, submounts) + rootInode, err = addSubmountOverlay(ctx, rootInode, submounts, mf) if err != nil { return nil, fmt.Errorf("adding submount overlay: %v", err) } @@ -754,7 +761,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (* // getMountNameAndOptions retrieves the fsName, opts, and useOverlay values // used for mounts. -func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (string, []string, bool, error) { +func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.Mount) (string, []string, bool, error) { var ( fsName string opts []string @@ -788,19 +795,19 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( return fsName, opts, useOverlay, nil } -func (c *containerMounter) getMountAccessType(mount specs.Mount) FileAccessType { +func (c *containerMounter) getMountAccessType(mount specs.Mount) config.FileAccessType { if hint := c.hints.findMount(mount); hint != nil { return hint.fileAccessType() } // Non-root bind mounts are always shared if no hints were provided. - return FileAccessShared + return config.FileAccessShared } // mountSubmount mounts volumes inside the container's root. Because mounts may // be readonly, a lower ramfs overlay is added to create the mount point dir. // Another overlay is added with tmpfs on top if Config.Overlay is true. // 'm.Destination' must be an absolute path with '..' and symlinks resolved. -func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error { +func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) @@ -844,7 +851,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns submounts := subtargets(m.Destination, c.mounts) if len(submounts) > 0 { log.Infof("Adding submount overlay over %q", m.Destination) - inode, err = addSubmountOverlay(ctx, inode, submounts) + inode, err = addSubmountOverlay(ctx, inode, submounts, mf) if err != nil { return fmt.Errorf("adding submount overlay: %v", err) } @@ -863,7 +870,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns if err != nil { return fmt.Errorf("can't find mount destination %q: %v", m.Destination, err) } - defer dirent.DecRef() + defer dirent.DecRef(ctx) if err := mns.Mount(ctx, dirent, inode); err != nil { return fmt.Errorf("mount %q error: %v", m.Destination, err) } @@ -884,12 +891,12 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun if err != nil { return fmt.Errorf("can't find mount destination %q: %v", mount.Destination, err) } - defer target.DecRef() + defer target.DecRef(ctx) // Take a ref on the inode that is about to be (re)-mounted. source.root.IncRef() if err := mns.Mount(ctx, target, source.root); err != nil { - source.root.DecRef() + source.root.DecRef(ctx) return fmt.Errorf("bind mount %q error: %v", mount.Destination, err) } @@ -899,7 +906,7 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun // addRestoreMount adds a mount to the MountSources map used for restoring a // checkpointed container. -func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnvironment, m specs.Mount) error { +func (c *containerMounter) addRestoreMount(conf *config.Config, renv *fs.RestoreEnvironment, m specs.Mount) error { fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) if err != nil { return err @@ -924,7 +931,7 @@ func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnviron // createRestoreEnvironment builds a fs.RestoreEnvironment called renv by adding // the mounts to the environment. -func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEnvironment, error) { +func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.RestoreEnvironment, error) { renv := &fs.RestoreEnvironment{ MountSources: make(map[string][]fs.MountArgs), } @@ -979,7 +986,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEn // // Note that when there are submounts inside of '/tmp', directories for the // mount points must be present, making '/tmp' not empty anymore. -func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent) error { +func (c *containerMounter) mountTmp(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent) error { for _, m := range c.mounts { if filepath.Clean(m.Destination) == "/tmp" { log.Debugf("Explict %q mount found, skipping internal tmpfs, mount: %+v", "/tmp", m) @@ -992,12 +999,12 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.M switch err { case nil: // Found '/tmp' in filesystem, check if it's empty. - defer tmp.DecRef() + defer tmp.DecRef(ctx) f, err := tmp.Inode.GetFile(ctx, tmp, fs.FileFlags{Read: true, Directory: true}) if err != nil { return err } - defer f.DecRef() + defer f.DecRef(ctx) serializer := &fs.CollectEntriesSerializer{} if err := f.Readdir(ctx, serializer); err != nil { return err diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index 912037075..e986231e5 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -20,6 +20,7 @@ import ( "testing" specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/runsc/config" ) func TestPodMountHintsHappy(t *testing.T) { @@ -196,7 +197,7 @@ func TestGetMountAccessType(t *testing.T) { for _, tst := range []struct { name string annotations map[string]string - want FileAccessType + want config.FileAccessType }{ { name: "container=exclusive", @@ -205,7 +206,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "container", }, - want: FileAccessExclusive, + want: config.FileAccessExclusive, }, { name: "pod=shared", @@ -214,7 +215,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "pod", }, - want: FileAccessShared, + want: config.FileAccessShared, }, { name: "shared=shared", @@ -223,7 +224,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "shared", }, - want: FileAccessShared, + want: config.FileAccessShared, }, { name: "default=shared", @@ -232,7 +233,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "container", }, - want: FileAccessShared, + want: config.FileAccessShared, }, } { t.Run(tst.name, func(t *testing.T) { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index b5df1deb9..dee2c4fbb 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -16,22 +16,25 @@ package boot import ( + "errors" "fmt" mrand "math/rand" "os" "runtime" "sync/atomic" - "syscall" gtime "time" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/memutil" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fdimport" @@ -66,7 +69,9 @@ import ( "gvisor.dev/gvisor/runsc/boot/filter" _ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms. "gvisor.dev/gvisor/runsc/boot/pprof" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" + "gvisor.dev/gvisor/runsc/specutils/seccomp" // Include supported socket providers. "gvisor.dev/gvisor/pkg/sentry/socket/hostinet" @@ -77,6 +82,22 @@ import ( _ "gvisor.dev/gvisor/pkg/sentry/socket/unix" ) +type containerInfo struct { + conf *config.Config + + // spec is the base configuration for the root container. + spec *specs.Spec + + // procArgs refers to the container's init task. + procArgs kernel.CreateProcessArgs + + // stdioFDs contains stdin, stdout, and stderr. + stdioFDs []*fd.FD + + // goferFDs are the FDs that attach the sandbox to the gofers. + goferFDs []*fd.FD +} + // Loader keeps state needed to start the kernel and run the container.. type Loader struct { // k is the kernel. @@ -85,22 +106,11 @@ type Loader struct { // ctrl is the control server. ctrl *controller - conf *Config - - // console is set to true if terminal is enabled. - console bool + // root contains information about the root container in the sandbox. + root containerInfo watchdog *watchdog.Watchdog - // stdioFDs contains stdin, stdout, and stderr. - stdioFDs []int - - // goferFDs are the FDs that attach the sandbox to the gofers. - goferFDs []int - - // spec is the base configuration for the root container. - spec *specs.Spec - // stopSignalForwarding disables forwarding of signals to the sandboxed // container. It should be called when a sandbox is destroyed. stopSignalForwarding func() @@ -108,9 +118,6 @@ type Loader struct { // restore is set to true if we are restoring a container. restore bool - // rootProcArgs refers to the root sandbox init task. - rootProcArgs kernel.CreateProcessArgs - // sandboxID is the ID for the whole sandbox. sandboxID string @@ -162,7 +169,7 @@ type Args struct { // Spec is the sandbox specification. Spec *specs.Spec // Conf is the system configuration. - Conf *Config + Conf *config.Config // ControllerFD is the FD to the URPC controller. The Loader takes ownership // of this FD and may close it at any time. ControllerFD int @@ -175,8 +182,6 @@ type Args struct { // StdioFDs is the stdio for the application. The Loader takes ownership of // these FDs and may close them at any time. StdioFDs []int - // Console is set to true if using TTY. - Console bool // NumCPU is the number of CPUs to create inside the sandbox. NumCPU int // TotalMem is the initial amount of total memory to report back to the @@ -187,7 +192,7 @@ type Args struct { } // make sure stdioFDs are always the same on initial start and on restore -const startingStdioFD = 64 +const startingStdioFD = 256 // New initializes a new kernel loader configured by spec. // New also handles setting up a kernel for restoring a container. @@ -205,6 +210,10 @@ func New(args Args) (*Loader, error) { // Is this a VFSv2 kernel? if args.Conf.VFS2 { kernel.VFS2Enabled = true + if args.Conf.FUSE { + kernel.FUSEEnabled = true + } + vfs2.Override() } @@ -227,9 +236,7 @@ func New(args Args) (*Loader, error) { // Create VDSO. // // Pass k as the platform since it is savable, unlike the actual platform. - // - // FIXME(b/109889800): Use non-nil context. - vdso, err := loader.PrepareVDSO(nil, k) + vdso, err := loader.PrepareVDSO(k) if err != nil { return nil, fmt.Errorf("creating vdso: %v", err) } @@ -275,6 +282,7 @@ func New(args Args) (*Loader, error) { args.NumCPU = runtime.NumCPU() } log.Infof("CPUs: %d", args.NumCPU) + runtime.GOMAXPROCS(args.NumCPU) if args.TotalMem > 0 { // Adjust the total memory returned by the Sentry so that applications that @@ -300,6 +308,12 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("initializing kernel: %v", err) } + if kernel.VFS2Enabled { + if err := registerFilesystems(k); err != nil { + return nil, fmt.Errorf("registering filesystems: %w", err) + } + } + if err := adjustDirentCache(k); err != nil { return nil, err } @@ -318,7 +332,7 @@ func New(args Args) (*Loader, error) { dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction dog := watchdog.New(k, dogOpts) - procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace()) + procArgs, err := createProcessArgs(args.ID, args.Spec, creds, k, k.RootPIDNamespace()) if err != nil { return nil, fmt.Errorf("creating init process for root container: %v", err) } @@ -338,7 +352,7 @@ func New(args Args) (*Loader, error) { if err != nil { return nil, fmt.Errorf("failed to create hostfs filesystem: %v", err) } - defer hostFilesystem.DecRef() + defer hostFilesystem.DecRef(k.SupervisorContext()) hostMount, err := k.VFS().NewDisconnectedMount(hostFilesystem, nil, &vfs.MountOptions{}) if err != nil { return nil, fmt.Errorf("failed to create hostfs mount: %v", err) @@ -346,37 +360,45 @@ func New(args Args) (*Loader, error) { k.SetHostMount(hostMount) } + info := containerInfo{ + conf: args.Conf, + spec: args.Spec, + procArgs: procArgs, + } + // Make host FDs stable between invocations. Host FDs must map to the exact // same number when the sandbox is restored. Otherwise the wrong FD will be // used. - var stdioFDs []int newfd := startingStdioFD - for _, fd := range args.StdioFDs { - err := syscall.Dup3(fd, newfd, syscall.O_CLOEXEC) - if err != nil { - return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err) + for _, stdioFD := range args.StdioFDs { + // Check that newfd is unused to avoid clobbering over it. + if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) { + if err != nil { + return nil, fmt.Errorf("error checking for FD (%d) conflict: %w", newfd, err) + } + return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd) } - stdioFDs = append(stdioFDs, newfd) - err = syscall.Close(fd) + + err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC) if err != nil { - return nil, fmt.Errorf("close original stdioFDs failed: %v", err) + return nil, fmt.Errorf("dup3 of stdios failed: %w", err) } + info.stdioFDs = append(info.stdioFDs, fd.New(newfd)) + _ = unix.Close(stdioFD) newfd++ } + for _, goferFD := range args.GoferFDs { + info.goferFDs = append(info.goferFDs, fd.New(goferFD)) + } eid := execID{cid: args.ID} l := &Loader{ - k: k, - conf: args.Conf, - console: args.Console, - watchdog: dog, - spec: args.Spec, - goferFDs: args.GoferFDs, - stdioFDs: stdioFDs, - rootProcArgs: procArgs, - sandboxID: args.ID, - processes: map[execID]*execProcess{eid: {}}, - mountHints: mountHints, + k: k, + watchdog: dog, + sandboxID: args.ID, + processes: map[execID]*execProcess{eid: {}}, + mountHints: mountHints, + root: info, } // We don't care about child signals; some platforms can generate a @@ -404,8 +426,8 @@ func New(args Args) (*Loader, error) { return l, nil } -// newProcess creates a process that can be run with kernel.CreateProcess. -func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) { +// createProcessArgs creates args that can be used with kernel.CreateProcess. +func createProcessArgs(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) { // Create initial limits. ls, err := createLimitSet(spec) if err != nil { @@ -449,9 +471,19 @@ func (l *Loader) Destroy() { l.stopSignalForwarding() } l.watchdog.Stop() + + // In the success case, stdioFDs and goferFDs will only contain + // released/closed FDs that ownership has been passed over to host FDs and + // gofer sessions. Close them here in case on failure. + for _, fd := range l.root.stdioFDs { + _ = fd.Close() + } + for _, fd := range l.root.goferFDs { + _ = fd.Close() + } } -func createPlatform(conf *Config, deviceFile *os.File) (platform.Platform, error) { +func createPlatform(conf *config.Config, deviceFile *os.File) (platform.Platform, error) { p, err := platform.Lookup(conf.Platform) if err != nil { panic(fmt.Sprintf("invalid platform %v: %v", conf.Platform, err)) @@ -478,14 +510,15 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) { return mf, nil } +// installSeccompFilters installs sandbox seccomp filters with the host. func (l *Loader) installSeccompFilters() error { - if l.conf.DisableSeccomp { + if l.root.conf.DisableSeccomp { filter.Report("syscall filter is DISABLED. Running in less secure mode.") } else { opts := filter.Options{ Platform: l.k.Platform, - HostNetwork: l.conf.Network == NetworkHost, - ProfileEnable: l.conf.ProfileEnable, + HostNetwork: l.root.conf.Network == config.NetworkHost, + ProfileEnable: l.root.conf.ProfileEnable, ControllerFD: l.ctrl.srv.FD(), } if err := filter.Install(opts); err != nil { @@ -511,7 +544,7 @@ func (l *Loader) Run() error { } func (l *Loader) run() error { - if l.conf.Network == NetworkHost { + if l.root.conf.Network == config.NetworkHost { // Delay host network configuration to this point because network namespace // is configured after the loader is created and before Run() is called. log.Debugf("Configuring host network") @@ -532,10 +565,8 @@ func (l *Loader) run() error { // If we are restoring, we do not want to create a process. // l.restore is set by the container manager when a restore call is made. - var ttyFile *host.TTYFileOperations - var ttyFileVFS2 *hostvfs2.TTYFileDescription if !l.restore { - if l.conf.ProfileEnable { + if l.root.conf.ProfileEnable { pprof.Initialize() } @@ -545,82 +576,30 @@ func (l *Loader) run() error { return err } - // Create the FD map, which will set stdin, stdout, and stderr. If console - // is true, then ioctl calls will be passed through to the host fd. - ctx := l.rootProcArgs.NewContext(l.k) - var err error - - // CreateProcess takes a reference on FDMap if successful. We won't need - // ours either way. - l.rootProcArgs.FDTable, ttyFile, ttyFileVFS2, err = createFDTable(ctx, l.console, l.stdioFDs) - if err != nil { - return fmt.Errorf("importing fds: %v", err) - } - - // Setup the root container file system. - l.startGoferMonitor(l.sandboxID, l.goferFDs) - - mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints) - if err := mntr.processHints(l.conf); err != nil { - return err - } - if err := setupContainerFS(ctx, l.conf, mntr, &l.rootProcArgs); err != nil { - return err - } - - // Add the HOME enviroment variable if it is not already set. - var envv []string - if kernel.VFS2Enabled { - envv, err = user.MaybeAddExecUserHomeVFS2(ctx, l.rootProcArgs.MountNamespaceVFS2, - l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) - - } else { - envv, err = user.MaybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, - l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) - } - if err != nil { - return err - } - l.rootProcArgs.Envv = envv - // Create the root container init task. It will begin running // when the kernel is started. - if _, _, err := l.k.CreateProcess(l.rootProcArgs); err != nil { - return fmt.Errorf("creating init process: %v", err) + if _, err := l.createContainerProcess(true, l.sandboxID, &l.root, ep); err != nil { + return err } - // CreateProcess takes a reference on FDTable if successful. - l.rootProcArgs.FDTable.DecRef() } ep.tg = l.k.GlobalInit() - if ns, ok := specutils.GetNS(specs.PIDNamespace, l.spec); ok { + if ns, ok := specutils.GetNS(specs.PIDNamespace, l.root.spec); ok { ep.pidnsPath = ns.Path } - if l.console { - // Set the foreground process group on the TTY to the global init process - // group, since that is what we are about to start running. - switch { - case ttyFileVFS2 != nil: - ep.ttyVFS2 = ttyFileVFS2 - ttyFileVFS2.InitForegroundProcessGroup(ep.tg.ProcessGroup()) - case ttyFile != nil: - ep.tty = ttyFile - ttyFile.InitForegroundProcessGroup(ep.tg.ProcessGroup()) - } - } // Handle signals by forwarding them to the root container process // (except for panic signal, which should cause a panic). l.stopSignalForwarding = sighandling.StartSignalForwarding(func(sig linux.Signal) { // Panic signal should cause a panic. - if l.conf.PanicSignal != -1 && sig == linux.Signal(l.conf.PanicSignal) { + if l.root.conf.PanicSignal != -1 && sig == linux.Signal(l.root.conf.PanicSignal) { panic("Signal-induced panic") } // Otherwise forward to root container. deliveryMode := DeliverToProcess - if l.console { + if l.root.spec.Process.Terminal { // Since we are running with a console, we should forward the signal to // the foreground process group so that job control signals like ^C can // be handled properly. @@ -632,19 +611,6 @@ func (l *Loader) run() error { } }) - // l.stdioFDs are derived from dup() in boot.New() and they are now dup()ed again - // either in createFDTable() during initial start or in descriptor.initAfterLoad() - // during restore, we can release l.stdioFDs now. VFS2 takes ownership of the - // passed FDs, so only close for VFS1. - if !kernel.VFS2Enabled { - for _, fd := range l.stdioFDs { - err := syscall.Close(fd) - if err != nil { - return fmt.Errorf("close dup()ed stdioFDs: %v", err) - } - } - } - log.Infof("Process should have started...") l.watchdog.Start() return l.k.Start() @@ -664,9 +630,9 @@ func (l *Loader) createContainer(cid string) error { } // startContainer starts a child container. It returns the thread group ID of -// the newly created process. Caller owns 'files' and may close them after -// this method returns. -func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, files []*os.File) error { +// the newly created process. Used FDs are either closed or released. It's safe +// for the caller to close any remaining files upon return. +func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid string, files []*fd.FD) error { // Create capabilities. caps, err := specutils.Capabilities(conf.EnableRaw, spec.Process.Capabilities) if err != nil { @@ -676,8 +642,8 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file l.mu.Lock() defer l.mu.Unlock() - eid := execID{cid: cid} - if _, ok := l.processes[eid]; !ok { + ep := l.processes[execID{cid: cid}] + if ep == nil { return fmt.Errorf("trying to start a deleted container %q", cid) } @@ -711,88 +677,136 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file if pidns == nil { pidns = l.k.RootPIDNamespace().NewChild(l.k.RootUserNamespace()) } - l.processes[eid].pidnsPath = ns.Path + ep.pidnsPath = ns.Path } else { pidns = l.k.RootPIDNamespace() } - procArgs, err := newProcess(cid, spec, creds, l.k, pidns) + + info := &containerInfo{ + conf: conf, + spec: spec, + stdioFDs: files[:3], + goferFDs: files[3:], + } + info.procArgs, err = createProcessArgs(cid, spec, creds, l.k, pidns) if err != nil { return fmt.Errorf("creating new process: %v", err) } + tg, err := l.createContainerProcess(false, cid, info, ep) + if err != nil { + return err + } + + // Success! + l.k.StartProcess(tg) + ep.tg = tg + return nil +} - // setupContainerFS() dups stdioFDs, so we don't need to dup them here. - var stdioFDs []int - for _, f := range files[:3] { - stdioFDs = append(stdioFDs, int(f.Fd())) +func (l *Loader) createContainerProcess(root bool, cid string, info *containerInfo, ep *execProcess) (*kernel.ThreadGroup, error) { + console := false + if root { + // Only root container supports terminal for now. + console = info.spec.Process.Terminal } // Create the FD map, which will set stdin, stdout, and stderr. - ctx := procArgs.NewContext(l.k) - fdTable, _, _, err := createFDTable(ctx, false, stdioFDs) + ctx := info.procArgs.NewContext(l.k) + fdTable, ttyFile, ttyFileVFS2, err := createFDTable(ctx, console, info.stdioFDs) if err != nil { - return fmt.Errorf("importing fds: %v", err) - } - // CreateProcess takes a reference on fdTable if successful. We won't - // need ours either way. - procArgs.FDTable = fdTable - - // Can't take ownership away from os.File. dup them to get a new FDs. - var goferFDs []int - for _, f := range files[3:] { - fd, err := syscall.Dup(int(f.Fd())) - if err != nil { - return fmt.Errorf("failed to dup file: %v", err) - } - goferFDs = append(goferFDs, fd) + return nil, fmt.Errorf("importing fds: %v", err) } + // CreateProcess takes a reference on fdTable if successful. We won't need + // ours either way. + info.procArgs.FDTable = fdTable // Setup the child container file system. - l.startGoferMonitor(cid, goferFDs) + l.startGoferMonitor(cid, info.goferFDs) - mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints) - if err := setupContainerFS(ctx, conf, mntr, &procArgs); err != nil { - return err + mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints) + if root { + if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil { + return nil, err + } + } + if err := setupContainerFS(ctx, info.conf, mntr, &info.procArgs); err != nil { + return nil, err } // Add the HOME enviroment variable if it is not already set. var envv []string if kernel.VFS2Enabled { - envv, err = user.MaybeAddExecUserHomeVFS2(ctx, procArgs.MountNamespaceVFS2, - procArgs.Credentials.RealKUID, procArgs.Envv) + envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2, + info.procArgs.Credentials.RealKUID, info.procArgs.Envv) } else { - envv, err = user.MaybeAddExecUserHome(ctx, procArgs.MountNamespace, - procArgs.Credentials.RealKUID, procArgs.Envv) + envv, err = user.MaybeAddExecUserHome(ctx, info.procArgs.MountNamespace, + info.procArgs.Credentials.RealKUID, info.procArgs.Envv) } if err != nil { - return err + return nil, err } - procArgs.Envv = envv + info.procArgs.Envv = envv // Create and start the new process. - tg, _, err := l.k.CreateProcess(procArgs) + tg, _, err := l.k.CreateProcess(info.procArgs) if err != nil { - return fmt.Errorf("creating process: %v", err) + return nil, fmt.Errorf("creating process: %v", err) } - l.k.StartProcess(tg) - // CreateProcess takes a reference on FDTable if successful. - procArgs.FDTable.DecRef() + info.procArgs.FDTable.DecRef(ctx) - l.processes[eid].tg = tg - return nil + // Set the foreground process group on the TTY to the global init process + // group, since that is what we are about to start running. + if root { + switch { + case ttyFileVFS2 != nil: + ep.ttyVFS2 = ttyFileVFS2 + ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup()) + case ttyFile != nil: + ep.tty = ttyFile + ttyFile.InitForegroundProcessGroup(tg.ProcessGroup()) + } + } + + // Install seccomp filters with the new task if there are any. + if info.conf.OCISeccomp { + if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil { + program, err := seccomp.BuildProgram(info.spec.Linux.Seccomp) + if err != nil { + return nil, fmt.Errorf("building seccomp program: %v", err) + } + + if log.IsLogging(log.Debug) { + out, _ := bpf.DecodeProgram(program) + log.Debugf("Installing OCI seccomp filters\nProgram:\n%s", out) + } + + task := tg.Leader() + // NOTE: It seems Flags are ignored by runc so we ignore them too. + if err := task.AppendSyscallFilter(program, true); err != nil { + return nil, fmt.Errorf("appending seccomp filters: %v", err) + } + } + } else { + if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil { + log.Warningf("Seccomp spec is being ignored") + } + } + + return tg, nil } // startGoferMonitor runs a goroutine to monitor gofer's health. It polls on -// the gofer FDs looking for disconnects, and destroys the container if a +// the gofer FDs looking for disconnects, and kills the container processes if a // disconnect occurs in any of the gofer FDs. -func (l *Loader) startGoferMonitor(cid string, goferFDs []int) { +func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) { go func() { log.Debugf("Monitoring gofer health for container %q", cid) var events []unix.PollFd - for _, fd := range goferFDs { + for _, goferFD := range goferFDs { events = append(events, unix.PollFd{ - Fd: int32(fd), + Fd: int32(goferFD.FD()), Events: unix.POLLHUP | unix.POLLRDHUP, }) } @@ -805,18 +819,15 @@ func (l *Loader) startGoferMonitor(cid string, goferFDs []int) { panic(fmt.Sprintf("Error monitoring gofer FDs: %v", err)) } - // Check if the gofer has stopped as part of normal container destruction. - // This is done just to avoid sending an annoying error message to the log. - // Note that there is a small race window in between mu.Unlock() and the - // lock being reacquired in destroyContainer(), but it's harmless to call - // destroyContainer() multiple times. l.mu.Lock() - _, ok := l.processes[execID{cid: cid}] - l.mu.Unlock() - if ok { - log.Infof("Gofer socket disconnected, destroying container %q", cid) - if err := l.destroyContainer(cid); err != nil { - log.Warningf("Error destroying container %q after gofer stopped: %v", cid, err) + defer l.mu.Unlock() + + // The gofer could have been stopped due to a normal container shutdown. + // Check if the container has not stopped yet. + if tg, _ := l.tryThreadGroupFromIDLocked(execID{cid: cid}); tg != nil { + log.Infof("Gofer socket disconnected, killing container %q", cid) + if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil { + log.Warningf("Error killing container %q after gofer stopped: %v", cid, err) } } }() @@ -885,37 +896,42 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) { return 0, fmt.Errorf("container %q not started", args.ContainerID) } - // Get the container MountNamespace from the Task. + // Get the container MountNamespace from the Task. Try to acquire ref may fail + // in case it raced with task exit. if kernel.VFS2Enabled { // task.MountNamespace() does not take a ref, so we must do so ourselves. args.MountNamespaceVFS2 = tg.Leader().MountNamespaceVFS2() - args.MountNamespaceVFS2.IncRef() + if !args.MountNamespaceVFS2.TryIncRef() { + return 0, fmt.Errorf("container %q has stopped", args.ContainerID) + } } else { + var reffed bool tg.Leader().WithMuLocked(func(t *kernel.Task) { // task.MountNamespace() does not take a ref, so we must do so ourselves. args.MountNamespace = t.MountNamespace() - args.MountNamespace.IncRef() + reffed = args.MountNamespace.TryIncRef() }) + if !reffed { + return 0, fmt.Errorf("container %q has stopped", args.ContainerID) + } } // Add the HOME environment variable if it is not already set. if kernel.VFS2Enabled { - defer args.MountNamespaceVFS2.DecRef() - root := args.MountNamespaceVFS2.Root() - defer root.DecRef() ctx := vfs.WithRoot(l.k.SupervisorContext(), root) + defer args.MountNamespaceVFS2.DecRef(ctx) + defer root.DecRef(ctx) envv, err := user.MaybeAddExecUserHomeVFS2(ctx, args.MountNamespaceVFS2, args.KUID, args.Envv) if err != nil { return 0, err } args.Envv = envv } else { - defer args.MountNamespace.DecRef() - root := args.MountNamespace.Root() - defer root.DecRef() ctx := fs.WithRoot(l.k.SupervisorContext(), root) + defer args.MountNamespace.DecRef(ctx) + defer root.DecRef(ctx) envv, err := user.MaybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv) if err != nil { return 0, err @@ -1012,20 +1028,25 @@ func (l *Loader) WaitExit() kernel.ExitStatus { // Wait for container. l.k.WaitExited() + // Cleanup + l.ctrl.stop() + + refs.OnExit() + return l.k.GlobalInit().ExitStatus() } -func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) { +func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) { // Create an empty network stack because the network namespace may be empty at // this point. Netns is configured before Run() is called. Netstack is // configured using a control uRPC message. Host network is configured inside // Run(). switch conf.Network { - case NetworkHost: + case config.NetworkHost: // No network namespacing support for hostinet yet, hence creator is nil. return inet.NewRootNamespace(hostinet.NewStack(), nil), nil - case NetworkNone, NetworkSandbox: + case config.NetworkNone, config.NetworkSandbox: s, err := newEmptySandboxNetworkStack(clock, uniqueID) if err != nil { return nil, err @@ -1043,8 +1064,8 @@ func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.Uni } func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4} s := netstack.Stack{stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, @@ -1058,17 +1079,30 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in })} // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { - return nil, fmt.Errorf("failed to enable SACK: %s", err) + { + opt := tcpip.TCPSACKEnabled(true) + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Set default TTLs as required by socket/netstack. - s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) - s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + { + opt := tcpip.DefaultTTLOption(netstack.DefaultTTL) + if err := s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv4.ProtocolNumber, opt, opt, err) + } + if err := s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv6.ProtocolNumber, opt, opt, err) + } + } // Enable Receive Buffer Auto-Tuning. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } return &s, nil @@ -1264,7 +1298,7 @@ func (l *Loader) ttyFromIDLocked(key execID) (*host.TTYFileOperations, *hostvfs2 return ep.tty, ep.ttyVFS2, nil } -func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { +func createFDTable(ctx context.Context, console bool, stdioFDs []*fd.FD) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { if len(stdioFDs) != 3 { return nil, nil, nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs)) } @@ -1273,7 +1307,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F fdTable := k.NewFDTable() ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, console, stdioFDs) if err != nil { - fdTable.DecRef() + fdTable.DecRef(ctx) return nil, nil, nil, err } return fdTable, ttyFile, ttyFileVFS2, nil diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index e448fd773..1f49431a2 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -26,6 +26,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/control/server" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" @@ -34,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/fsgofer" ) @@ -43,15 +45,19 @@ func init() { if err := fsgofer.OpenProcSelfFD(); err != nil { panic(err) } + config.RegisterFlags() } -func testConfig() *Config { - return &Config{ - RootDir: "unused_root_dir", - Network: NetworkNone, - DisableSeccomp: true, - Platform: "ptrace", +func testConfig() *config.Config { + conf, err := config.NewFromFlags() + if err != nil { + panic(err) } + // Change test defaults. + conf.RootDir = "unused_root_dir" + conf.Network = config.NetworkNone + conf.DisableSeccomp = true + return conf } // testSpec returns a simple spec that can be used in tests. @@ -258,7 +264,7 @@ type CreateMountTestcase struct { expectedPaths []string } -func createMountTestcases(vfs2 bool) []*CreateMountTestcase { +func createMountTestcases() []*CreateMountTestcase { testCases := []*CreateMountTestcase{ &CreateMountTestcase{ // Only proc. @@ -403,32 +409,26 @@ func createMountTestcases(vfs2 bool) []*CreateMountTestcase { Destination: "/proc", Type: "tmpfs", }, - // TODO (gvisor.dev/issue/1487): Re-add this case when sysfs supports - // MkDirAt in VFS2 (and remove the reduntant append). - // { - // Destination: "/sys/bar", - // Type: "tmpfs", - // }, - // + { + Destination: "/sys/bar", + Type: "tmpfs", + }, + { Destination: "/tmp/baz", Type: "tmpfs", }, }, }, - expectedPaths: []string{"/proc", "/sys" /* "/sys/bar" ,*/, "/tmp", "/tmp/baz"}, + expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"}, } - if !vfs2 { - vfsCase.spec.Mounts = append(vfsCase.spec.Mounts, specs.Mount{Destination: "/sys/bar", Type: "tmpfs"}) - vfsCase.expectedPaths = append(vfsCase.expectedPaths, "/sys/bar") - } return append(testCases, vfsCase) } // Test that MountNamespace can be created with various specs. func TestCreateMountNamespace(t *testing.T) { - for _, tc := range createMountTestcases(false /* vfs2 */) { + for _, tc := range createMountTestcases() { t.Run(tc.name, func(t *testing.T) { conf := testConfig() ctx := contexttest.Context(t) @@ -439,7 +439,7 @@ func TestCreateMountNamespace(t *testing.T) { } defer cleanup() - mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{}) + mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}) mns, err := mntr.createMountNamespace(ctx, conf) if err != nil { t.Fatalf("failed to create mount namespace: %v", err) @@ -450,13 +450,13 @@ func TestCreateMountNamespace(t *testing.T) { } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range tc.expectedPaths { maxTraversals := uint(0) if d, err := mns.FindInode(ctx, root, root, p, &maxTraversals); err != nil { t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err) } else { - d.DecRef() + d.DecRef(ctx) } } }) @@ -465,7 +465,7 @@ func TestCreateMountNamespace(t *testing.T) { // Test that MountNamespace can be created with various specs. func TestCreateMountNamespaceVFS2(t *testing.T) { - for _, tc := range createMountTestcases(true /* vfs2 */) { + for _, tc := range createMountTestcases() { t.Run(tc.name, func(t *testing.T) { spec := testSpec() spec.Mounts = tc.spec.Mounts @@ -479,19 +479,19 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { defer l.Destroy() defer loaderCleanup() - mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints) - if err := mntr.processHints(l.conf); err != nil { + mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints) + if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil { t.Fatalf("failed process hints: %v", err) } ctx := l.k.SupervisorContext() - mns, err := mntr.setupVFS2(ctx, l.conf, &l.rootProcArgs) + mns, err := mntr.mountAll(l.root.conf, &l.root.procArgs) if err != nil { - t.Fatalf("failed to setupVFS2: %v", err) + t.Fatalf("mountAll: %v", err) } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range tc.expectedPaths { target := &vfs.PathOperation{ Root: root, @@ -499,10 +499,10 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { Path: fspath.Parse(p), } - if d, err := l.k.VFS().GetDentryAt(ctx, l.rootProcArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil { + if d, err := l.k.VFS().GetDentryAt(ctx, l.root.procArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil { t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err) } else { - d.DecRef() + d.DecRef(ctx) } } }) @@ -545,7 +545,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, }, "tmpfs": { @@ -599,7 +599,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, { Dev: "9pfs-/dev/fd-foo", @@ -657,7 +657,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, }, "tmpfs": { @@ -697,7 +697,11 @@ func TestRestoreEnvironment(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { conf := testConfig() - mntr := newContainerMounter(tc.spec, tc.ioFDs, nil, &podMountHints{}) + var ioFDs []*fd.FD + for _, ioFD := range tc.ioFDs { + ioFDs = append(ioFDs, fd.New(ioFD)) + } + mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}) actualRenv, err := mntr.createRestoreEnvironment(conf) if !tc.errorExpected && err != nil { t.Fatalf("could not create restore environment for test:%s", tc.name) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 14d2f56a5..988573640 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket" "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -32,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" + "gvisor.dev/gvisor/runsc/config" ) var ( @@ -77,44 +79,6 @@ type DefaultRoute struct { Name string } -// QueueingDiscipline is used to specify the kind of Queueing Discipline to -// apply for a give FDBasedLink. -type QueueingDiscipline int - -const ( - // QDiscNone disables any queueing for the underlying FD. - QDiscNone QueueingDiscipline = iota - - // QDiscFIFO applies a simple fifo based queue to the underlying - // FD. - QDiscFIFO -) - -// MakeQueueingDiscipline if possible the equivalent QueuingDiscipline for s -// else returns an error. -func MakeQueueingDiscipline(s string) (QueueingDiscipline, error) { - switch s { - case "none": - return QDiscNone, nil - case "fifo": - return QDiscFIFO, nil - default: - return 0, fmt.Errorf("unsupported qdisc specified: %q", s) - } -} - -// String implements fmt.Stringer. -func (q QueueingDiscipline) String() string { - switch q { - case QDiscNone: - return "none" - case QDiscFIFO: - return "fifo" - default: - panic(fmt.Sprintf("Invalid queueing discipline: %d", q)) - } -} - // FDBasedLink configures an fd-based link. type FDBasedLink struct { Name string @@ -126,7 +90,7 @@ type FDBasedLink struct { TXChecksumOffload bool RXChecksumOffload bool LinkAddress net.HardwareAddr - QDisc QueueingDiscipline + QDisc config.QueueingDiscipline // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -246,12 +210,15 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } switch link.QDisc { - case QDiscNone: - case QDiscFIFO: + case config.QDiscNone: + case config.QDiscFIFO: log.Infof("Enabling FIFO QDisc on %q", link.Name) linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000) } + // Enable support for AF_PACKET sockets to receive outgoing packets. + linkEP = packetsocket.New(linkEP) + log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err diff --git a/runsc/boot/strace.go b/runsc/boot/strace.go index fbfd3b07c..c21648a32 100644 --- a/runsc/boot/strace.go +++ b/runsc/boot/strace.go @@ -15,10 +15,13 @@ package boot import ( + "strings" + "gvisor.dev/gvisor/pkg/sentry/strace" + "gvisor.dev/gvisor/runsc/config" ) -func enableStrace(conf *Config) error { +func enableStrace(conf *config.Config) error { // We must initialize even if strace is not enabled. strace.Initialize() @@ -36,5 +39,5 @@ func enableStrace(conf *Config) error { strace.EnableAll(strace.SinkTypeLog) return nil } - return strace.Enable(conf.StraceSyscalls, strace.SinkTypeLog) + return strace.Enable(strings.Split(conf.StraceSyscalls, ","), strace.SinkTypeLog) } diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index b68117867..e36664938 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -16,12 +16,12 @@ package boot import ( "fmt" - "path" "sort" "strings" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" @@ -37,13 +37,19 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/runsc/config" ) -func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) error { +func registerFilesystems(k *kernel.Kernel) error { + ctx := k.SupervisorContext() + creds := auth.NewRootCredentials(k.RootUserNamespace()) + vfsObj := k.VFS() + vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, // TODO(b/29356795): Users may mount this once the terminals are in a @@ -73,6 +79,10 @@ func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, cre AllowUserMount: true, AllowUserList: true, }) + vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) // Setup files in devtmpfs. if err := memdev.Register(vfsObj); err != nil { @@ -81,18 +91,24 @@ func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, cre if err := ttydev.Register(vfsObj); err != nil { return fmt.Errorf("registering ttydev: %w", err) } - - if err := fuse.Register(vfsObj); err != nil { - return fmt.Errorf("registering fusedev: %w", err) + tunSupported := tundev.IsNetTunSupported(inet.StackFromContext(ctx)) + if tunSupported { + if err := tundev.Register(vfsObj); err != nil { + return fmt.Errorf("registering tundev: %v", err) + } } - if err := tundev.Register(vfsObj); err != nil { - return fmt.Errorf("registering tundev: %v", err) + + if kernel.FUSEEnabled { + if err := fuse.Register(vfsObj); err != nil { + return fmt.Errorf("registering fusedev: %w", err) + } } + a, err := devtmpfs.NewAccessor(ctx, vfsObj, creds, devtmpfs.Name) if err != nil { return fmt.Errorf("creating devtmpfs accessor: %w", err) } - defer a.Release() + defer a.Release(ctx) if err := a.UserspaceInit(ctx); err != nil { return fmt.Errorf("initializing userspace: %w", err) @@ -103,20 +119,23 @@ func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, cre if err := ttydev.CreateDevtmpfsFiles(ctx, a); err != nil { return fmt.Errorf("creating ttydev devtmpfs files: %w", err) } - if err := tundev.CreateDevtmpfsFiles(ctx, a); err != nil { - return fmt.Errorf("creating tundev devtmpfs files: %v", err) + if tunSupported { + if err := tundev.CreateDevtmpfsFiles(ctx, a); err != nil { + return fmt.Errorf("creating tundev devtmpfs files: %v", err) + } } - if err := fuse.CreateDevtmpfsFile(ctx, a); err != nil { - return fmt.Errorf("creating fusedev devtmpfs files: %w", err) + + if kernel.FUSEEnabled { + if err := fuse.CreateDevtmpfsFile(ctx, a); err != nil { + return fmt.Errorf("creating fusedev devtmpfs files: %w", err) + } } + return nil } -func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { - if err := mntr.k.VFS().Init(); err != nil { - return fmt.Errorf("failed to initialize VFS: %w", err) - } - mns, err := mntr.setupVFS2(ctx, conf, procArgs) +func setupContainerVFS2(ctx context.Context, conf *config.Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { + mns, err := mntr.mountAll(conf, procArgs) if err != nil { return fmt.Errorf("failed to setupFS: %w", err) } @@ -131,7 +150,7 @@ func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounte return nil } -func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) { +func (c *containerMounter) mountAll(conf *config.Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) { log.Infof("Configuring container's file system with VFS2") // Create context with root credentials to mount the filesystem (the current @@ -144,36 +163,115 @@ func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals rootCtx := procArgs.NewContext(c.k) - if err := registerFilesystems(rootCtx, c.k.VFS(), rootCreds); err != nil { - return nil, fmt.Errorf("register filesystems: %w", err) - } - mns, err := c.createMountNamespaceVFS2(rootCtx, conf, rootCreds) if err != nil { return nil, fmt.Errorf("creating mount namespace: %w", err) } rootProcArgs.MountNamespaceVFS2 = mns + root := mns.Root() + defer root.DecRef(rootCtx) + if root.Mount().ReadOnly() { + // Switch to ReadWrite while we setup submounts. + if err := c.k.VFS().SetMountReadOnly(root.Mount(), false); err != nil { + return nil, fmt.Errorf(`failed to set mount at "/" readwrite: %w`, err) + } + // Restore back to ReadOnly at the end. + defer func() { + if err := c.k.VFS().SetMountReadOnly(root.Mount(), true); err != nil { + panic(fmt.Sprintf(`failed to restore mount at "/" back to readonly: %v`, err)) + } + }() + } + // Mount submounts. if err := c.mountSubmountsVFS2(rootCtx, conf, mns, rootCreds); err != nil { return nil, fmt.Errorf("mounting submounts vfs2: %w", err) } + return mns, nil } -func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *Config, creds *auth.Credentials) (*vfs.MountNamespace, error) { +// createMountNamespaceVFS2 creates the container's root mount and namespace. +func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *config.Config, creds *auth.Credentials) (*vfs.MountNamespace, error) { fd := c.fds.remove() - opts := strings.Join(p9MountData(fd, conf.FileAccess, true /* vfs2 */), ",") + data := p9MountData(fd, conf.FileAccess, true /* vfs2 */) + + if conf.OverlayfsStaleRead { + // We can't check for overlayfs here because sandbox is chroot'ed and gofer + // can only send mount options for specs.Mounts (specs.Root is missing + // Options field). So assume root is always on top of overlayfs. + data = append(data, "overlayfs_stale_read") + } log.Infof("Mounting root over 9P, ioFD: %d", fd) - mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{Data: opts}) + opts := &vfs.MountOptions{ + ReadOnly: c.root.Readonly, + GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: strings.Join(data, ","), + }, + InternalMount: true, + } + + fsName := gofer.Name + if conf.Overlay && !c.root.Readonly { + log.Infof("Adding overlay on top of root") + var err error + var cleanup func() + opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) + if err != nil { + return nil, fmt.Errorf("mounting root with overlay: %w", err) + } + defer cleanup() + fsName = overlay.Name + } + + mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", fsName, opts) if err != nil { return nil, fmt.Errorf("setting up mount namespace: %w", err) } return mns, nil } -func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error { +// configureOverlay mounts the lower layer using "lowerOpts", mounts the upper +// layer using tmpfs, and return overlay mount options. "cleanup" must be called +// after the options have been used to mount the overlay, to release refs on +// lower and upper mounts. +func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Credentials, lowerOpts *vfs.MountOptions, lowerFSName string) (*vfs.MountOptions, func(), error) { + // First copy options from lower layer to upper layer and overlay. Clear + // filesystem specific options. + upperOpts := *lowerOpts + upperOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{} + + overlayOpts := *lowerOpts + overlayOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{} + + // Next mount upper and lower. Upper is a tmpfs mount to keep all + // modifications inside the sandbox. + upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts) + if err != nil { + return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err) + } + cu := cleanup.Make(func() { upper.DecRef(ctx) }) + defer cu.Clean() + + // All writes go to the upper layer, be paranoid and make lower readonly. + lowerOpts.ReadOnly = true + lower, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, lowerFSName, lowerOpts) + if err != nil { + return nil, nil, err + } + cu.Add(func() { lower.DecRef(ctx) }) + + // Configure overlay with both layers. + overlayOpts.GetFilesystemOptions.InternalData = overlay.FilesystemOptions{ + UpperRoot: vfs.MakeVirtualDentry(upper, upper.Root()), + LowerRoots: []vfs.VirtualDentry{vfs.MakeVirtualDentry(lower, lower.Root())}, + } + return &overlayOpts, cu.Release(), nil +} + +func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials) error { mounts, err := c.prepareMountsVFS2() if err != nil { return err @@ -182,8 +280,34 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, for i := range mounts { submount := &mounts[i] log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options) - if err := c.mountSubmountVFS2(ctx, conf, mns, creds, submount); err != nil { - return err + var ( + mnt *vfs.Mount + err error + ) + + if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() { + mnt, err = c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint) + if err != nil { + return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err) + } + } else { + mnt, err = c.mountSubmountVFS2(ctx, conf, mns, creds, submount) + if err != nil { + return fmt.Errorf("mount submount %q: %w", submount.Destination, err) + } + } + + if mnt != nil && mnt.ReadOnly() { + // Switch to ReadWrite while we setup submounts. + if err := c.k.VFS().SetMountReadOnly(mnt, false); err != nil { + return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.Destination, err) + } + // Restore back to ReadOnly at the end. + defer func() { + if err := c.k.VFS().SetMountReadOnly(mnt, true); err != nil { + panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.Destination, err)) + } + }() } } @@ -227,62 +351,83 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { return mounts, nil } -func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) error { - root := mns.Root() - defer root.DecRef() - target := &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(submount.Destination), - } - fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, submount) +func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) (*vfs.Mount, error) { + fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, submount) if err != nil { - return fmt.Errorf("mountOptions failed: %w", err) + return nil, fmt.Errorf("mountOptions failed: %w", err) } if len(fsName) == 0 { // Filesystem is not supported (e.g. cgroup), just skip it. - return nil + return nil, nil } - if err := c.makeSyntheticMount(ctx, submount.Destination, root, creds); err != nil { - return err + if err := c.makeMountPoint(ctx, creds, mns, submount.Destination); err != nil { + return nil, fmt.Errorf("creating mount point %q: %w", submount.Destination, err) } - if err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts); err != nil { - return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) + + if useOverlay { + log.Infof("Adding overlay on top of mount %q", submount.Destination) + var cleanup func() + opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) + if err != nil { + return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.Destination, err) + } + defer cleanup() + fsName = overlay.Name + } + + root := mns.Root() + defer root.DecRef(ctx) + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(submount.Destination), + } + mnt, err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts) + if err != nil { + return nil, fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) } log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts.GetFilesystemOptions.Data) - return nil + return mnt, nil } // getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values // used for mounts. -func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndFD) (string, *vfs.MountOptions, error) { - var ( - fsName string - data []string - ) +func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mountAndFD) (string, *vfs.MountOptions, bool, error) { + fsName := m.Type + useOverlay := false + var data []string // Find filesystem name and FS specific data field. switch m.Type { case devpts.Name, devtmpfs.Name, proc.Name, sys.Name: - fsName = m.Type + // Nothing to do. + case nonefs: fsName = sys.Name - case tmpfs.Name: - fsName = m.Type + case tmpfs.Name: var err error data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...) if err != nil { - return "", nil, err + return "", nil, false, err } case bind: fsName = gofer.Name + if m.fd == 0 { + // Check that an FD was provided to fails fast. Technically FD=0 is valid, + // but unlikely to be correct in this context. + return "", nil, false, fmt.Errorf("9P mount requires a connection FD") + } data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */) + // If configured, add overlay to all writable mounts. + useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly + default: log.Warningf("ignoring unknown filesystem type %q", m.Type) + return "", nil, false, nil } opts := &vfs.MountOptions{ @@ -307,38 +452,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndF } } - if conf.Overlay { - // All writes go to upper, be paranoid and make lower readonly. - opts.ReadOnly = true - } - return fsName, opts, nil -} - -func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath string, root vfs.VirtualDentry, creds *auth.Credentials) error { - target := &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(currentPath), - } - _, err := c.k.VFS().StatAt(ctx, creds, target, &vfs.StatOptions{}) - if err == nil { - // Mount point exists, nothing else to do. - return nil - } - if err != syserror.ENOENT { - return fmt.Errorf("stat failed for %q during mount point creation: %w", currentPath, err) - } - - // Recurse to ensure parent is created and then create the mount point. - if err := c.makeSyntheticMount(ctx, path.Dir(currentPath), root, creds); err != nil { - return err - } - log.Debugf("Creating dir %q for mount point", currentPath) - mkdirOpts := &vfs.MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true} - if err := c.k.VFS().MkdirAt(ctx, creds, target, mkdirOpts); err != nil { - return fmt.Errorf("failed to create directory %q for mount: %w", currentPath, err) - } - return nil + return fsName, opts, useOverlay, nil } // mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so. @@ -350,7 +464,7 @@ func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath s // // Note that when there are submounts inside of '/tmp', directories for the // mount points must be present, making '/tmp' not empty anymore. -func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds *auth.Credentials, mns *vfs.MountNamespace) error { +func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config, creds *auth.Credentials, mns *vfs.MountNamespace) error { for _, m := range c.mounts { // m.Destination has been cleaned, so it's to use equality here. if m.Destination == "/tmp" { @@ -360,28 +474,35 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) pop := vfs.PathOperation{ Root: root, Start: root, Path: fspath.Parse("/tmp"), } // TODO(gvisor.dev/issue/2782): Use O_PATH when available. - statx, err := c.k.VFS().StatAt(ctx, creds, &pop, &vfs.StatOptions{}) + fd, err := c.k.VFS().OpenAt(ctx, creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY}) switch err { case nil: - // Found '/tmp' in filesystem, check if it's empty. - if linux.FileMode(statx.Mode).FileType() != linux.ModeDirectory { - // Not a dir?! Leave it be. + defer fd.DecRef(ctx) + + err := fd.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { + if dirent.Name != "." && dirent.Name != ".." { + return syserror.ENOTEMPTY + } return nil - } - if statx.Nlink > 2 { + })) + switch err { + case nil: + log.Infof(`Mounting internal tmpfs on top of empty "/tmp"`) + case syserror.ENOTEMPTY: // If more than "." and ".." is found, skip internal tmpfs to prevent // hiding existing files. log.Infof(`Skipping internal tmpfs mount for "/tmp" because it's not empty`) return nil + default: + return err } - log.Infof(`Mounting internal tmpfs on top of empty "/tmp"`) fallthrough case syserror.ENOENT: @@ -394,9 +515,122 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds // another user. This is normally done for /tmp. Options: []string{"mode=01777"}, } - return c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount}) + _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount}) + return err + + case syserror.ENOTDIR: + // Not a dir?! Let it be. + return nil default: - return fmt.Errorf(`stating "/tmp" inside container: %w`, err) + return fmt.Errorf(`opening "/tmp" inside container: %w`, err) + } +} + +// processHintsVFS2 processes annotations that container hints about how volumes +// should be mounted (e.g. a volume shared between containers). It must be +// called for the root container only. +func (c *containerMounter) processHintsVFS2(conf *config.Config, creds *auth.Credentials) error { + ctx := c.k.SupervisorContext() + for _, hint := range c.hints.mounts { + // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a + // common gofer to mount all shared volumes. + if hint.mount.Type != tmpfs.Name { + continue + } + + log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type) + mnt, err := c.mountSharedMasterVFS2(ctx, conf, hint, creds) + if err != nil { + return fmt.Errorf("mounting shared master %q: %v", hint.name, err) + } + hint.vfsMount = mnt + } + return nil +} + +// mountSharedMasterVFS2 mounts the master of a volume that is shared among +// containers in a pod. +func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *config.Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) { + // Map mount type to filesystem name, and parse out the options that we are + // capable of dealing with. + mntFD := &mountAndFD{Mount: hint.mount} + fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, mntFD) + if err != nil { + return nil, err + } + if len(fsName) == 0 { + return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type) + } + + if useOverlay { + log.Infof("Adding overlay on top of shared mount %q", mntFD.Destination) + var cleanup func() + opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) + if err != nil { + return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.Destination, err) + } + defer cleanup() + fsName = overlay.Name + } + + return c.k.VFS().MountDisconnected(ctx, creds, "", fsName, opts) +} + +// mountSharedSubmount binds mount to a previously mounted volume that is shared +// among containers in the same pod. +func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) (*vfs.Mount, error) { + if err := source.checkCompatible(mount); err != nil { + return nil, err + } + + // Ignore data and useOverlay because these were already applied to + // the master mount. + _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount}) + if err != nil { + return nil, err + } + newMnt, err := c.k.VFS().NewDisconnectedMount(source.vfsMount.Filesystem(), source.vfsMount.Root(), opts) + if err != nil { + return nil, err + } + defer newMnt.DecRef(ctx) + + root := mns.Root() + defer root.DecRef(ctx) + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(mount.Destination), + } + + if err := c.makeMountPoint(ctx, creds, mns, mount.Destination); err != nil { + return nil, fmt.Errorf("creating mount point %q: %w", mount.Destination, err) + } + + if err := c.k.VFS().ConnectMountAt(ctx, creds, newMnt, target); err != nil { + return nil, err + } + log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name) + return newMnt, nil +} + +func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, dest string) error { + root := mns.Root() + defer root.DecRef(ctx) + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(dest), + } + // First check if mount point exists. When overlay is enabled, gofer doesn't + // allow changes to the FS, making MakeSytheticMountpoint() ineffective + // because MkdirAt fails with EROFS even if file exists. + vd, err := c.k.VFS().GetDentryAt(ctx, creds, target, &vfs.GetDentryOptions{}) + if err == nil { + // File exists, we're done. + vd.DecRef(ctx) + return nil } + return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds) } diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD index 7e34a284a..37f4253ba 100644 --- a/runsc/cgroup/BUILD +++ b/runsc/cgroup/BUILD @@ -10,7 +10,7 @@ go_library( "//pkg/cleanup", "//pkg/log", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) @@ -22,6 +22,6 @@ go_test( tags = ["local"], deps = [ "//pkg/test/testutil", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index e5cc9d622..8fbc3887a 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -92,7 +92,17 @@ func setOptionalValueUint16(path, name string, val *uint16) error { func setValue(path, name, data string) error { fullpath := filepath.Join(path, name) - return ioutil.WriteFile(fullpath, []byte(data), 0700) + + // Retry writes on EINTR; see: + // https://github.com/golang/go/issues/38033 + for { + err := ioutil.WriteFile(fullpath, []byte(data), 0700) + if err == nil { + return nil + } else if !errors.Is(err, syscall.EINTR) { + return err + } + } } func getValue(path, name string) (string, error) { @@ -132,8 +142,16 @@ func fillFromAncestor(path string) (string, error) { if err != nil { return "", err } - if err := ioutil.WriteFile(path, []byte(val), 0700); err != nil { - return "", err + + // Retry writes on EINTR; see: + // https://github.com/golang/go/issues/38033 + for { + err := ioutil.WriteFile(path, []byte(val), 0700) + if err == nil { + break + } else if !errors.Is(err, syscall.EINTR) { + return "", err + } } return val, nil } diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index dae9b3b3e..2556f6d9e 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/unet", "//pkg/urpc", "//runsc/boot", + "//runsc/config", "//runsc/console", "//runsc/container", "//runsc/flag", @@ -58,7 +59,7 @@ go_library( "//runsc/fsgofer/filter", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], @@ -84,12 +85,12 @@ go_test( "//pkg/sentry/kernel/auth", "//pkg/test/testutil", "//pkg/urpc", - "//runsc/boot", + "//runsc/config", "//runsc/container", "//runsc/specutils", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", ], ) diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go index 01204ab4d..cd419e1aa 100644 --- a/runsc/cmd/boot.go +++ b/runsc/cmd/boot.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" ) @@ -54,10 +55,6 @@ type Boot struct { // provided in that order. stdioFDs intFlags - // console is set to true if the sandbox should allow terminal ioctl(2) - // syscalls. - console bool - // applyCaps determines if capabilities defined in the spec should be applied // to the process. applyCaps bool @@ -115,7 +112,6 @@ func (b *Boot) SetFlags(f *flag.FlagSet) { f.IntVar(&b.deviceFD, "device-fd", -1, "FD for the platform device file") f.Var(&b.ioFDs, "io-fds", "list of FDs to connect 9P clients. They must follow this order: root first, then mounts as defined in the spec") f.Var(&b.stdioFDs, "stdio-fds", "list of FDs containing sandbox stdin, stdout, and stderr in that order") - f.BoolVar(&b.console, "console", false, "set to true if the sandbox should allow terminal ioctl(2) syscalls") f.BoolVar(&b.applyCaps, "apply-caps", false, "if true, apply capabilities defined in the spec to the process") f.BoolVar(&b.setUpRoot, "setup-root", false, "if true, set up an empty root for the process") f.BoolVar(&b.pidns, "pidns", false, "if true, the sandbox is in its own PID namespace") @@ -138,7 +134,7 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Ensure that if there is a panic, all goroutine stacks are printed. debug.SetTraceback("system") - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if b.attached { // Ensure this process is killed after parent process terminates when @@ -172,7 +168,7 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Get the spec from the specFD. specFile := os.NewFile(uintptr(b.specFD), "spec file") defer specFile.Close() - spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile) + spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile, conf) if err != nil { Fatalf("reading spec: %v", err) } @@ -229,7 +225,6 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Device: os.NewFile(uintptr(b.deviceFD), "platform device"), GoferFDs: b.ioFDs.GetArray(), StdioFDs: b.stdioFDs.GetArray(), - Console: b.console, NumCPU: b.cpuNum, TotalMem: b.totalMem, UserLogFD: b.userLogFD, diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go index a84067112..e13a94486 100644 --- a/runsc/cmd/capability_test.go +++ b/runsc/cmd/capability_test.go @@ -24,7 +24,7 @@ import ( "github.com/syndtr/gocapability/capability" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/specutils" ) @@ -88,7 +88,7 @@ func TestCapabilities(t *testing.T) { conf := testutil.TestConfig(t) // Use --network=host to make sandbox use spec's capabilities. - conf.Network = boot.NetworkHost + conf.Network = config.NetworkHost _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go index 8a29e521e..8fe0c427a 100644 --- a/runsc/cmd/checkpoint.go +++ b/runsc/cmd/checkpoint.go @@ -22,7 +22,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -72,7 +72,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) cont, err := container.Load(conf.RootDir, id) @@ -118,7 +118,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa Fatalf("setting bundleDir") } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { Fatalf("reading spec: %v", err) } diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index 910e97577..e76f7ba1d 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -81,7 +81,7 @@ func (c *Create) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if conf.Rootless { return Errorf("Rootless mode not supported with %q", c.Name()) @@ -91,7 +91,7 @@ func (c *Create) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} if bundleDir == "" { bundleDir = getwdOrDie() } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { return Errorf("reading spec: %v", err) } diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index b5de2588b..132198222 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -25,27 +25,26 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) // Debug implements subcommands.Command for the "debug" command. type Debug struct { - pid int - stacks bool - signal int - profileHeap string - profileCPU string - profileGoroutine string - profileBlock string - profileMutex string - trace string - strace string - logLevel string - logPackets string - duration time.Duration - ps bool + pid int + stacks bool + signal int + profileHeap string + profileCPU string + profileBlock string + profileMutex string + trace string + strace string + logLevel string + logPackets string + duration time.Duration + ps bool } // Name implements subcommands.Command. @@ -69,7 +68,6 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log") f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.") f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.") - f.StringVar(&d.profileGoroutine, "profile-goroutine", "", "writes goroutine profile to the given file.") f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.") f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.") f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles") @@ -84,7 +82,7 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { // Execute implements subcommands.Command.Execute. func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { var c *container.Container - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if d.pid == 0 { // No pid, container ID must have been provided. @@ -153,18 +151,6 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } log.Infof("Heap profile written to %q", d.profileHeap) } - if d.profileGoroutine != "" { - f, err := os.Create(d.profileGoroutine) - if err != nil { - return Errorf(err.Error()) - } - defer f.Close() - - if err := c.Sandbox.GoroutineProfile(f); err != nil { - return Errorf(err.Error()) - } - log.Infof("Goroutine profile written to %q", d.profileGoroutine) - } if d.profileBlock != "" { f, err := os.Create(d.profileBlock) if err != nil { diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go index 0e4863f50..4e49deff8 100644 --- a/runsc/cmd/delete.go +++ b/runsc/cmd/delete.go @@ -21,7 +21,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -59,14 +59,14 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} return subcommands.ExitUsageError } - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if err := d.execute(f.Args(), conf); err != nil { Fatalf("%v", err) } return subcommands.ExitSuccess } -func (d *Delete) execute(ids []string, conf *boot.Config) error { +func (d *Delete) execute(ids []string, conf *config.Config) error { for _, id := range ids { c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/delete_test.go b/runsc/cmd/delete_test.go index cb59516a3..e2d994a05 100644 --- a/runsc/cmd/delete_test.go +++ b/runsc/cmd/delete_test.go @@ -18,7 +18,7 @@ import ( "io/ioutil" "testing" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" ) func TestNotFound(t *testing.T) { @@ -27,7 +27,7 @@ func TestNotFound(t *testing.T) { if err != nil { t.Fatalf("error creating dir: %v", err) } - conf := &boot.Config{RootDir: dir} + conf := &config.Config{RootDir: dir} d := Delete{} if err := d.execute(ids, conf); err == nil { diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index 7d1310c96..d1f2e9e6d 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -30,7 +30,7 @@ import ( "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -82,7 +82,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su return subcommands.ExitUsageError } - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) if conf.Rootless { @@ -125,7 +125,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su specutils.LogSpec(spec) cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) - if conf.Network == boot.NetworkNone { + if conf.Network == config.NetworkNone { netns := specs.LinuxNamespace{ Type: specs.NetworkNamespace, } @@ -135,9 +135,9 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su spec.Linux = &specs.Linux{Namespaces: []specs.LinuxNamespace{netns}} } else if conf.Rootless { - if conf.Network == boot.NetworkSandbox { + if conf.Network == config.NetworkSandbox { c.notifyUser("*** Warning: using host network due to --rootless ***") - conf.Network = boot.NetworkHost + conf.Network = config.NetworkHost } } else { diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go index 51f6a98ed..25fe2cf1c 100644 --- a/runsc/cmd/events.go +++ b/runsc/cmd/events.go @@ -22,7 +22,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -72,7 +72,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index d9a94903e..775ed4b43 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -33,7 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/urpc" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/console" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" @@ -105,7 +105,7 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) { // Execute implements subcommands.Command.Execute. It starts a process in an // already created container. func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) e, id, err := ex.parseArgs(f, conf.EnableRaw) if err != nil { Fatalf("parsing process spec: %v", err) @@ -220,7 +220,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi cmd.Stderr = os.Stderr // If the console control socket file is provided, then create a new - // pty master/slave pair and set the TTY on the sandbox process. + // pty master/replica pair and set the TTY on the sandbox process. if ex.consoleSocket != "" { // Create a new TTY pair and send the master on the provided socket. tty, err := console.NewWithSocket(ex.consoleSocket) @@ -229,7 +229,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi } defer tty.Close() - // Set stdio to the new TTY slave. + // Set stdio to the new TTY replica. cmd.Stdin = tty cmd.Stdout = tty cmd.Stderr = tty diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 3966e2d21..371fcc0ae 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -30,7 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/fsgofer" "gvisor.dev/gvisor/runsc/fsgofer/filter" @@ -62,9 +62,8 @@ type Gofer struct { applyCaps bool setUpRoot bool - panicOnWrite bool - specFD int - mountsFD int + specFD int + mountsFD int } // Name implements subcommands.Command. @@ -87,7 +86,6 @@ func (g *Gofer) SetFlags(f *flag.FlagSet) { f.StringVar(&g.bundleDir, "bundle", "", "path to the root of the bundle directory, defaults to the current directory") f.Var(&g.ioFDs, "io-fds", "list of FDs to connect 9P servers. They must follow this order: root first, then mounts as defined in the spec") f.BoolVar(&g.applyCaps, "apply-caps", true, "if true, apply capabilities to restrict what the Gofer process can do") - f.BoolVar(&g.panicOnWrite, "panic-on-write", false, "if true, panics on attempts to write to RO mounts. RW mounts are unnaffected") f.BoolVar(&g.setUpRoot, "setup-root", true, "if true, set up an empty root for the process") f.IntVar(&g.specFD, "spec-fd", -1, "required fd with the container spec") f.IntVar(&g.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to write list of mounts after they have been resolved (direct paths, no symlinks).") @@ -100,15 +98,15 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } + conf := args[0].(*config.Config) + specFile := os.NewFile(uintptr(g.specFD), "spec file") defer specFile.Close() - spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile) + spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile, conf) if err != nil { Fatalf("reading spec: %v", err) } - conf := args[0].(*boot.Config) - if g.setUpRoot { if err := setupRootFS(spec, conf); err != nil { Fatalf("Error setting up root FS: %v", err) @@ -168,8 +166,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Start with root mount, then add any other additional mount as needed. ats := make([]p9.Attacher, 0, len(spec.Mounts)+1) ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{ - ROMount: spec.Root.Readonly || conf.Overlay, - PanicOnWrite: g.panicOnWrite, + ROMount: spec.Root.Readonly || conf.Overlay, }) if err != nil { Fatalf("creating attach point: %v", err) @@ -181,9 +178,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) for _, m := range spec.Mounts { if specutils.Is9PMount(m) { cfg := fsgofer.Config{ - ROMount: isReadonlyMount(m.Options) || conf.Overlay, - PanicOnWrite: g.panicOnWrite, - HostUDS: conf.FSGoferHostUDS, + ROMount: isReadonlyMount(m.Options) || conf.Overlay, + HostUDS: conf.FSGoferHostUDS, } ap, err := fsgofer.NewAttachPoint(m.Destination, cfg) if err != nil { @@ -263,7 +259,7 @@ func isReadonlyMount(opts []string) bool { return false } -func setupRootFS(spec *specs.Spec, conf *boot.Config) error { +func setupRootFS(spec *specs.Spec, conf *config.Config) error { // Convert all shared mounts into slaves to be sure that nothing will be // propagated outside of our namespace. if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil { @@ -316,6 +312,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error { if err != nil { return fmt.Errorf("resolving symlinks to %q: %v", spec.Process.Cwd, err) } + log.Infof("Create working directory %q if needed", spec.Process.Cwd) if err := os.MkdirAll(dst, 0755); err != nil { return fmt.Errorf("creating working directory %q: %v", spec.Process.Cwd, err) } @@ -346,7 +343,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error { // setupMounts binds mount all mounts specified in the spec in their correct // location inside root. It will resolve relative paths and symlinks. It also // creates directories as needed. -func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error { +func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { for _, m := range mounts { if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { continue @@ -385,7 +382,7 @@ func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error { // Otherwise, it may follow symlinks to locations that would be overwritten // with another mount point and return the wrong location. In short, make sure // setupMounts() has been called before. -func resolveMounts(conf *boot.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { +func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { cleanMounts := make([]specs.Mount, 0, len(mounts)) for _, m := range mounts { if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { @@ -467,7 +464,7 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro } // adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs. -func adjustMountOptions(conf *boot.Config, path string, opts []string) ([]string, error) { +func adjustMountOptions(conf *config.Config, path string, opts []string) ([]string, error) { rv := make([]string, len(opts)) copy(rv, opts) diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go index 8282ea0e0..04eee99b2 100644 --- a/runsc/cmd/kill.go +++ b/runsc/cmd/kill.go @@ -23,7 +23,7 @@ import ( "github.com/google/subcommands" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -63,7 +63,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if k.pid != 0 && k.all { Fatalf("it is invalid to specify both --all and --pid") diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go index d8d906fe3..f92d6fef9 100644 --- a/runsc/cmd/list.go +++ b/runsc/cmd/list.go @@ -24,7 +24,7 @@ import ( "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -63,7 +63,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) ids, err := container.List(conf.RootDir) if err != nil { Fatalf("%v", err) diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go index 6f95a9837..0eb1402ed 100644 --- a/runsc/cmd/pause.go +++ b/runsc/cmd/pause.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -53,7 +53,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) cont, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go index 7fb8041af..bc58c928f 100644 --- a/runsc/cmd/ps.go +++ b/runsc/cmd/ps.go @@ -20,7 +20,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -58,7 +58,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go index 72584b326..096ec814c 100644 --- a/runsc/cmd/restore.go +++ b/runsc/cmd/restore.go @@ -20,7 +20,7 @@ import ( "syscall" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -77,7 +77,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{ } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) if conf.Rootless { @@ -88,7 +88,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{ if bundleDir == "" { bundleDir = getwdOrDie() } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { return Errorf("reading spec: %v", err) } diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go index 61a55a554..f24823f99 100644 --- a/runsc/cmd/resume.go +++ b/runsc/cmd/resume.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -54,7 +54,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) cont, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go index cf41581ad..c48cbe4cd 100644 --- a/runsc/cmd/run.go +++ b/runsc/cmd/run.go @@ -19,7 +19,7 @@ import ( "syscall" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -64,7 +64,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) if conf.Rootless { @@ -75,7 +75,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s if bundleDir == "" { bundleDir = getwdOrDie() } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { return Errorf("reading spec: %v", err) } diff --git a/runsc/cmd/spec.go b/runsc/cmd/spec.go index a2b0a4b14..55194e641 100644 --- a/runsc/cmd/spec.go +++ b/runsc/cmd/spec.go @@ -16,124 +16,122 @@ package cmd import ( "context" - "fmt" - "io/ioutil" + "encoding/json" + "io" "os" "path/filepath" "github.com/google/subcommands" + specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/runsc/flag" ) -func genSpec(cwd string) []byte { - var template = fmt.Sprintf(`{ - "ociVersion": "1.0.0", - "process": { - "terminal": true, - "user": { - "uid": 0, - "gid": 0 - }, - "args": [ - "sh" - ], - "env": [ - "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", - "TERM=xterm" - ], - "cwd": "%s", - "capabilities": { - "bounding": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "effective": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "inheritable": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "permitted": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "ambient": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ] - }, - "rlimits": [ - { - "type": "RLIMIT_NOFILE", - "hard": 1024, - "soft": 1024 - } - ] - }, - "root": { - "path": "rootfs", - "readonly": true - }, - "hostname": "runsc", - "mounts": [ - { - "destination": "/proc", - "type": "proc", - "source": "proc" +func writeSpec(w io.Writer, cwd string, netns string, args []string) error { + spec := &specs.Spec{ + Version: "1.0.0", + Process: &specs.Process{ + Terminal: true, + User: specs.User{ + UID: 0, + GID: 0, + }, + Args: args, + Env: []string{ + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "TERM=xterm", + }, + Cwd: cwd, + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + Effective: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + Inheritable: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + Permitted: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + // TODO(gvisor.dev/issue/3166): support ambient capabilities + }, + Rlimits: []specs.POSIXRlimit{ + { + Type: "RLIMIT_NOFILE", + Hard: 1024, + Soft: 1024, + }, + }, }, - { - "destination": "/dev", - "type": "tmpfs", - "source": "tmpfs", - "options": [] + Root: &specs.Root{ + Path: "rootfs", + Readonly: true, }, - { - "destination": "/sys", - "type": "sysfs", - "source": "sysfs", - "options": [ - "nosuid", - "noexec", - "nodev", - "ro" - ] - } - ], - "linux": { - "namespaces": [ + Hostname: "runsc", + Mounts: []specs.Mount{ { - "type": "pid" + Destination: "/proc", + Type: "proc", + Source: "proc", }, { - "type": "network" + Destination: "/dev", + Type: "tmpfs", + Source: "tmpfs", }, { - "type": "ipc" + Destination: "/sys", + Type: "sysfs", + Source: "sysfs", + Options: []string{ + "nosuid", + "noexec", + "nodev", + "ro", + }, }, - { - "type": "uts" + }, + Linux: &specs.Linux{ + Namespaces: []specs.LinuxNamespace{ + { + Type: "pid", + }, + { + Type: "network", + Path: netns, + }, + { + Type: "ipc", + }, + { + Type: "uts", + }, + { + Type: "mount", + }, }, - { - "type": "mount" - } - ] + }, } -}`, cwd) - return []byte(template) + e := json.NewEncoder(w) + e.SetIndent("", " ") + return e.Encode(spec) } // Spec implements subcommands.Command for the "spec" command. type Spec struct { bundle string cwd string + netns string } // Name implements subcommands.Command.Name. @@ -148,21 +146,26 @@ func (*Spec) Synopsis() string { // Usage implements subcommands.Command.Usage. func (*Spec) Usage() string { - return `spec [options] - create a new OCI bundle specification file. + return `spec [options] [-- args...] - create a new OCI bundle specification file. + +The spec command creates a new specification file (config.json) for a new OCI +bundle. -The spec command creates a new specification file (config.json) for a new OCI bundle. +The specification file is a starter file that runs the command specified by +'args' in the container. If 'args' is not specified the default is to run the +'sh' program. -The specification file is a starter file that runs the "sh" command in the container. You -should edit the file to suit your needs. You can find out more about the format of the -specification file by visiting the OCI runtime spec repository: +While a number of flags are provided to change values in the specification, you +can examine the file and edit it to suit your needs after this command runs. +You can find out more about the format of the specification file by visiting +the OCI runtime spec repository: https://github.com/opencontainers/runtime-spec/ EXAMPLE: $ mkdir -p bundle/rootfs $ cd bundle - $ runsc spec + $ runsc spec -- /hello $ docker export $(docker create hello-world) | tar -xf - -C rootfs - $ sed -i 's;"sh";"/hello";' config.json $ sudo runsc run hello ` @@ -173,18 +176,29 @@ func (s *Spec) SetFlags(f *flag.FlagSet) { f.StringVar(&s.bundle, "bundle", ".", "path to the root of the OCI bundle") f.StringVar(&s.cwd, "cwd", "/", "working directory that will be set for the executable, "+ "this value MUST be an absolute path") + f.StringVar(&s.netns, "netns", "", "network namespace path") } // Execute implements subcommands.Command.Execute. func (s *Spec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + // Grab the arguments. + containerArgs := f.Args() + if len(containerArgs) == 0 { + containerArgs = []string{"sh"} + } + confPath := filepath.Join(s.bundle, "config.json") if _, err := os.Stat(confPath); !os.IsNotExist(err) { Fatalf("file %q already exists", confPath) } - var spec = genSpec(s.cwd) + configFile, err := os.OpenFile(confPath, os.O_WRONLY|os.O_CREATE, 0664) + if err != nil { + Fatalf("opening file %q: %v", confPath, err) + } - if err := ioutil.WriteFile(confPath, spec, 0664); err != nil { + err = writeSpec(configFile, s.cwd, s.netns, containerArgs) + if err != nil { Fatalf("writing to %q: %v", confPath, err) } diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index 0205fd9f7..88991b521 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -52,7 +52,7 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go index cf2413deb..2bd2ab9f8 100644 --- a/runsc/cmd/state.go +++ b/runsc/cmd/state.go @@ -21,7 +21,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -55,7 +55,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go index 29c0a15f0..28d0642ed 100644 --- a/runsc/cmd/wait.go +++ b/runsc/cmd/wait.go @@ -21,7 +21,7 @@ import ( "syscall" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -70,7 +70,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/config/BUILD b/runsc/config/BUILD new file mode 100644 index 000000000..b1672bb9d --- /dev/null +++ b/runsc/config/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "config", + srcs = [ + "config.go", + "flags.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/refs", + "//pkg/sentry/watchdog", + "//pkg/sync", + "//runsc/flag", + ], +) + +go_test( + name = "config_test", + size = "small", + srcs = [ + "config_test.go", + ], + library = ":config", + deps = ["//runsc/flag"], +) diff --git a/runsc/boot/config.go b/runsc/config/config.go index bb01b8fb5..f30f79f68 100644 --- a/runsc/boot/config.go +++ b/runsc/config/config.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,220 +12,112 @@ // See the License for the specific language governing permissions and // limitations under the License. -package boot +// Package config provides basic infrastructure to set configuration settings +// for runsc. The configuration is set by flags to the command line. They can +// also propagate to a different process using the same flags. +package config import ( "fmt" - "strconv" - "strings" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/watchdog" ) -// FileAccessType tells how the filesystem is accessed. -type FileAccessType int - -const ( - // FileAccessShared sends IO requests to a Gofer process that validates the - // requests and forwards them to the host. - FileAccessShared FileAccessType = iota - - // FileAccessExclusive is the same as FileAccessShared, but enables - // extra caching for improved performance. It should only be used if - // the sandbox has exclusive access to the filesystem. - FileAccessExclusive -) - -// MakeFileAccessType converts type from string. -func MakeFileAccessType(s string) (FileAccessType, error) { - switch s { - case "shared": - return FileAccessShared, nil - case "exclusive": - return FileAccessExclusive, nil - default: - return 0, fmt.Errorf("invalid file access type %q", s) - } -} - -func (f FileAccessType) String() string { - switch f { - case FileAccessShared: - return "shared" - case FileAccessExclusive: - return "exclusive" - default: - return fmt.Sprintf("unknown(%d)", f) - } -} - -// NetworkType tells which network stack to use. -type NetworkType int - -const ( - // NetworkSandbox uses internal network stack, isolated from the host. - NetworkSandbox NetworkType = iota - - // NetworkHost redirects network related syscalls to the host network. - NetworkHost - - // NetworkNone sets up just loopback using netstack. - NetworkNone -) - -// MakeNetworkType converts type from string. -func MakeNetworkType(s string) (NetworkType, error) { - switch s { - case "sandbox": - return NetworkSandbox, nil - case "host": - return NetworkHost, nil - case "none": - return NetworkNone, nil - default: - return 0, fmt.Errorf("invalid network type %q", s) - } -} - -func (n NetworkType) String() string { - switch n { - case NetworkSandbox: - return "sandbox" - case NetworkHost: - return "host" - case NetworkNone: - return "none" - default: - return fmt.Sprintf("unknown(%d)", n) - } -} - -// MakeWatchdogAction converts type from string. -func MakeWatchdogAction(s string) (watchdog.Action, error) { - switch strings.ToLower(s) { - case "log", "logwarning": - return watchdog.LogWarning, nil - case "panic": - return watchdog.Panic, nil - default: - return 0, fmt.Errorf("invalid watchdog action %q", s) - } -} - -// MakeRefsLeakMode converts type from string. -func MakeRefsLeakMode(s string) (refs.LeakMode, error) { - switch strings.ToLower(s) { - case "disabled": - return refs.NoLeakChecking, nil - case "log-names": - return refs.LeaksLogWarning, nil - case "log-traces": - return refs.LeaksLogTraces, nil - default: - return 0, fmt.Errorf("invalid refs leakmode %q", s) - } -} - -func refsLeakModeToString(mode refs.LeakMode) string { - switch mode { - // If not set, default it to disabled. - case refs.UninitializedLeakChecking, refs.NoLeakChecking: - return "disabled" - case refs.LeaksLogWarning: - return "log-names" - case refs.LeaksLogTraces: - return "log-traces" - default: - panic(fmt.Sprintf("Invalid leakmode: %d", mode)) - } -} - // Config holds configuration that is not part of the runtime spec. +// +// Follow these steps to add a new flag: +// 1. Create a new field in Config. +// 2. Add a field tag with the flag name +// 3. Register a new flag in flags.go, with name and description +// 4. Add any necessary validation into validate() +// 5. If adding an enum, follow the same pattern as FileAccessType +// type Config struct { // RootDir is the runtime root directory. - RootDir string + RootDir string `flag:"root"` // Debug indicates that debug logging should be enabled. - Debug bool + Debug bool `flag:"debug"` // LogFilename is the filename to log to, if not empty. - LogFilename string + LogFilename string `flag:"log"` // LogFormat is the log format. - LogFormat string + LogFormat string `flag:"log-format"` // DebugLog is the path to log debug information to, if not empty. - DebugLog string + DebugLog string `flag:"debug-log"` // PanicLog is the path to log GO's runtime messages, if not empty. - PanicLog string + PanicLog string `flag:"panic-log"` // DebugLogFormat is the log format for debug. - DebugLogFormat string + DebugLogFormat string `flag:"debug-log-format"` // FileAccess indicates how the filesystem is accessed. - FileAccess FileAccessType + FileAccess FileAccessType `flag:"file-access"` // Overlay is whether to wrap the root filesystem in an overlay. - Overlay bool + Overlay bool `flag:"overlay"` // FSGoferHostUDS enables the gofer to mount a host UDS. - FSGoferHostUDS bool + FSGoferHostUDS bool `flag:"fsgofer-host-uds"` // Network indicates what type of network to use. - Network NetworkType + Network NetworkType `flag:"network"` // EnableRaw indicates whether raw sockets should be enabled. Raw // sockets are disabled by stripping CAP_NET_RAW from the list of // capabilities. - EnableRaw bool + EnableRaw bool `flag:"net-raw"` // HardwareGSO indicates that hardware segmentation offload is enabled. - HardwareGSO bool + HardwareGSO bool `flag:"gso"` // SoftwareGSO indicates that software segmentation offload is enabled. - SoftwareGSO bool + SoftwareGSO bool `flag:"software-gso"` // TXChecksumOffload indicates that TX Checksum Offload is enabled. - TXChecksumOffload bool + TXChecksumOffload bool `flag:"tx-checksum-offload"` // RXChecksumOffload indicates that RX Checksum Offload is enabled. - RXChecksumOffload bool + RXChecksumOffload bool `flag:"rx-checksum-offload"` // QDisc indicates the type of queuening discipline to use by default // for non-loopback interfaces. - QDisc QueueingDiscipline + QDisc QueueingDiscipline `flag:"qdisc"` // LogPackets indicates that all network packets should be logged. - LogPackets bool + LogPackets bool `flag:"log-packets"` // Platform is the platform to run on. - Platform string + Platform string `flag:"platform"` // Strace indicates that strace should be enabled. - Strace bool + Strace bool `flag:"strace"` - // StraceSyscalls is the set of syscalls to trace. If StraceEnable is - // true and this list is empty, then all syscalls will be traced. - StraceSyscalls []string + // StraceSyscalls is the set of syscalls to trace (comma-separated values). + // If StraceEnable is true and this string is empty, then all syscalls will + // be traced. + StraceSyscalls string `flag:"strace-syscalls"` // StraceLogSize is the max size of data blobs to display. - StraceLogSize uint + StraceLogSize uint `flag:"strace-log-size"` // DisableSeccomp indicates whether seccomp syscall filters should be // disabled. Pardon the double negation, but default to enabled is important. DisableSeccomp bool // WatchdogAction sets what action the watchdog takes when triggered. - WatchdogAction watchdog.Action + WatchdogAction watchdog.Action `flag:"watchdog-action"` // PanicSignal registers signal handling that panics. Usually set to // SIGUSR2(12) to troubleshoot hangs. -1 disables it. - PanicSignal int + PanicSignal int `flag:"panic-signal"` // ProfileEnable is set to prepare the sandbox to be profiled. - ProfileEnable bool + ProfileEnable bool `flag:"profile"` // RestoreFile is the path to the saved container image RestoreFile string @@ -233,97 +125,215 @@ type Config struct { // NumNetworkChannels controls the number of AF_PACKET sockets that map // to the same underlying network device. This allows netstack to better // scale for high throughput use cases. - NumNetworkChannels int + NumNetworkChannels int `flag:"num-network-channels"` // Rootless allows the sandbox to be started with a user that is not root. // Defense is depth measures are weaker with rootless. Specifically, the // sandbox and Gofer process run as root inside a user namespace with root // mapped to the caller's user. - Rootless bool + Rootless bool `flag:"rootless"` // AlsoLogToStderr allows to send log messages to stderr. - AlsoLogToStderr bool + AlsoLogToStderr bool `flag:"alsologtostderr"` // ReferenceLeakMode sets reference leak check mode - ReferenceLeakMode refs.LeakMode + ReferenceLeak refs.LeakMode `flag:"ref-leak-mode"` // OverlayfsStaleRead instructs the sandbox to assume that the root mount // is on a Linux overlayfs mount, which does not necessarily preserve // coherence between read-only and subsequent writable file descriptors // representing the "same" file. - OverlayfsStaleRead bool + OverlayfsStaleRead bool `flag:"overlayfs-stale-read"` + + // CPUNumFromQuota sets CPU number count to available CPU quota, using + // least integer value greater than or equal to quota. + // + // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2. + CPUNumFromQuota bool `flag:"cpu-num-from-quota"` + + // Enables VFS2. + VFS2 bool `flag:"vfs2"` + + // Enables FUSE usage. + FUSE bool `flag:"fuse"` + + // Allows overriding of flags in OCI annotations. + AllowFlagOverride bool `flag:"allow-flag-override"` + + // Enables seccomp inside the sandbox. + OCISeccomp bool `flag:"oci-seccomp"` // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in // tests. It allows runsc to start the sandbox process as the current // user, and without chrooting the sandbox process. This can be // necessary in test environments that have limited capabilities. - TestOnlyAllowRunAsCurrentUserWithoutChroot bool + TestOnlyAllowRunAsCurrentUserWithoutChroot bool `flag:"TESTONLY-unsafe-nonroot"` // TestOnlyTestNameEnv should only be used in tests. It looks up for the // test name in the container environment variables and adds it to the debug // log file name. This is done to help identify the log with the test when // multiple tests are run in parallel, since there is no way to pass // parameters to the runtime from docker. - TestOnlyTestNameEnv string + TestOnlyTestNameEnv string `flag:"TESTONLY-test-name-env"` +} - // CPUNumFromQuota sets CPU number count to available CPU quota, using - // least integer value greater than or equal to quota. - // - // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2. - CPUNumFromQuota bool +func (c *Config) validate() error { + if c.FileAccess == FileAccessShared && c.Overlay { + return fmt.Errorf("overlay flag is incompatible with shared file access") + } + if c.NumNetworkChannels <= 0 { + return fmt.Errorf("num_network_channels must be > 0, got: %d", c.NumNetworkChannels) + } + return nil +} - // Enables VFS2 (not plumbled through yet). - VFS2 bool +// FileAccessType tells how the filesystem is accessed. +type FileAccessType int + +const ( + // FileAccessExclusive is the same as FileAccessShared, but enables + // extra caching for improved performance. It should only be used if + // the sandbox has exclusive access to the filesystem. + FileAccessExclusive FileAccessType = iota + + // FileAccessShared sends IO requests to a Gofer process that validates the + // requests and forwards them to the host. + FileAccessShared +) + +func fileAccessTypePtr(v FileAccessType) *FileAccessType { + return &v } -// ToFlags returns a slice of flags that correspond to the given Config. -func (c *Config) ToFlags() []string { - f := []string{ - "--root=" + c.RootDir, - "--debug=" + strconv.FormatBool(c.Debug), - "--log=" + c.LogFilename, - "--log-format=" + c.LogFormat, - "--debug-log=" + c.DebugLog, - "--panic-log=" + c.PanicLog, - "--debug-log-format=" + c.DebugLogFormat, - "--file-access=" + c.FileAccess.String(), - "--overlay=" + strconv.FormatBool(c.Overlay), - "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS), - "--network=" + c.Network.String(), - "--log-packets=" + strconv.FormatBool(c.LogPackets), - "--platform=" + c.Platform, - "--strace=" + strconv.FormatBool(c.Strace), - "--strace-syscalls=" + strings.Join(c.StraceSyscalls, ","), - "--strace-log-size=" + strconv.Itoa(int(c.StraceLogSize)), - "--watchdog-action=" + c.WatchdogAction.String(), - "--panic-signal=" + strconv.Itoa(c.PanicSignal), - "--profile=" + strconv.FormatBool(c.ProfileEnable), - "--net-raw=" + strconv.FormatBool(c.EnableRaw), - "--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels), - "--rootless=" + strconv.FormatBool(c.Rootless), - "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr), - "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode), - "--gso=" + strconv.FormatBool(c.HardwareGSO), - "--software-gso=" + strconv.FormatBool(c.SoftwareGSO), - "--rx-checksum-offload=" + strconv.FormatBool(c.RXChecksumOffload), - "--tx-checksum-offload=" + strconv.FormatBool(c.TXChecksumOffload), - "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead), - "--qdisc=" + c.QDisc.String(), +// Set implements flag.Value. +func (f *FileAccessType) Set(v string) error { + switch v { + case "shared": + *f = FileAccessShared + case "exclusive": + *f = FileAccessExclusive + default: + return fmt.Errorf("invalid file access type %q", v) } - if c.CPUNumFromQuota { - f = append(f, "--cpu-num-from-quota") + return nil +} + +// Get implements flag.Value. +func (f *FileAccessType) Get() interface{} { + return *f +} + +// String implements flag.Value. +func (f *FileAccessType) String() string { + switch *f { + case FileAccessShared: + return "shared" + case FileAccessExclusive: + return "exclusive" } - // Only include these if set since it is never to be used by users. - if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { - f = append(f, "--TESTONLY-unsafe-nonroot=true") + panic(fmt.Sprintf("Invalid file access type %v", *f)) +} + +// NetworkType tells which network stack to use. +type NetworkType int + +const ( + // NetworkSandbox uses internal network stack, isolated from the host. + NetworkSandbox NetworkType = iota + + // NetworkHost redirects network related syscalls to the host network. + NetworkHost + + // NetworkNone sets up just loopback using netstack. + NetworkNone +) + +func networkTypePtr(v NetworkType) *NetworkType { + return &v +} + +// Set implements flag.Value. +func (n *NetworkType) Set(v string) error { + switch v { + case "sandbox": + *n = NetworkSandbox + case "host": + *n = NetworkHost + case "none": + *n = NetworkNone + default: + return fmt.Errorf("invalid network type %q", v) + } + return nil +} + +// Get implements flag.Value. +func (n *NetworkType) Get() interface{} { + return *n +} + +// String implements flag.Value. +func (n *NetworkType) String() string { + switch *n { + case NetworkSandbox: + return "sandbox" + case NetworkHost: + return "host" + case NetworkNone: + return "none" } - if len(c.TestOnlyTestNameEnv) != 0 { - f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv) + panic(fmt.Sprintf("Invalid network type %v", *n)) +} + +// QueueingDiscipline is used to specify the kind of Queueing Discipline to +// apply for a give FDBasedLink. +type QueueingDiscipline int + +const ( + // QDiscNone disables any queueing for the underlying FD. + QDiscNone QueueingDiscipline = iota + + // QDiscFIFO applies a simple fifo based queue to the underlying FD. + QDiscFIFO +) + +func queueingDisciplinePtr(v QueueingDiscipline) *QueueingDiscipline { + return &v +} + +// Set implements flag.Value. +func (q *QueueingDiscipline) Set(v string) error { + switch v { + case "none": + *q = QDiscNone + case "fifo": + *q = QDiscFIFO + default: + return fmt.Errorf("invalid qdisc %q", v) } + return nil +} + +// Get implements flag.Value. +func (q *QueueingDiscipline) Get() interface{} { + return *q +} - if c.VFS2 { - f = append(f, "--vfs2=true") +// String implements flag.Value. +func (q *QueueingDiscipline) String() string { + switch *q { + case QDiscNone: + return "none" + case QDiscFIFO: + return "fifo" } + panic(fmt.Sprintf("Invalid qdisc %v", *q)) +} + +func leakModePtr(v refs.LeakMode) *refs.LeakMode { + return &v +} - return f +func watchdogActionPtr(v watchdog.Action) *watchdog.Action { + return &v } diff --git a/runsc/config/config_test.go b/runsc/config/config_test.go new file mode 100644 index 000000000..fb162b7eb --- /dev/null +++ b/runsc/config/config_test.go @@ -0,0 +1,272 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "strings" + "testing" + + "gvisor.dev/gvisor/runsc/flag" +) + +func init() { + RegisterFlags() +} + +func TestDefault(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + // "--root" is always set to something different than the default. Reset it + // to make it easier to test that default values do not generate flags. + c.RootDir = "" + + // All defaults doesn't require setting flags. + flags := c.ToFlags() + if len(flags) > 0 { + t.Errorf("default flags not set correctly for: %s", flags) + } +} + +func setDefault(name string) { + fl := flag.CommandLine.Lookup(name) + fl.Value.Set(fl.DefValue) +} + +func TestFromFlags(t *testing.T) { + flag.CommandLine.Lookup("root").Value.Set("some-path") + flag.CommandLine.Lookup("debug").Value.Set("true") + flag.CommandLine.Lookup("num-network-channels").Value.Set("123") + flag.CommandLine.Lookup("network").Value.Set("none") + defer func() { + setDefault("root") + setDefault("debug") + setDefault("num-network-channels") + setDefault("network") + }() + + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + if want := "some-path"; c.RootDir != want { + t.Errorf("RootDir=%v, want: %v", c.RootDir, want) + } + if want := true; c.Debug != want { + t.Errorf("Debug=%v, want: %v", c.Debug, want) + } + if want := 123; c.NumNetworkChannels != want { + t.Errorf("NumNetworkChannels=%v, want: %v", c.NumNetworkChannels, want) + } + if want := NetworkNone; c.Network != want { + t.Errorf("Network=%v, want: %v", c.Network, want) + } +} + +func TestToFlags(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.RootDir = "some-path" + c.Debug = true + c.NumNetworkChannels = 123 + c.Network = NetworkNone + + flags := c.ToFlags() + if len(flags) != 4 { + t.Errorf("wrong number of flags set, want: 4, got: %d: %s", len(flags), flags) + } + t.Logf("Flags: %s", flags) + fm := map[string]string{} + for _, f := range flags { + kv := strings.Split(f, "=") + fm[kv[0]] = kv[1] + } + for name, want := range map[string]string{ + "--root": "some-path", + "--debug": "true", + "--num-network-channels": "123", + "--network": "none", + } { + if got, ok := fm[name]; ok { + if got != want { + t.Errorf("flag %q, want: %q, got: %q", name, want, got) + } + } else { + t.Errorf("flag %q not set", name) + } + } +} + +// TestInvalidFlags checks that enum flags fail when value is not in enum set. +func TestInvalidFlags(t *testing.T) { + for _, tc := range []struct { + name string + error string + }{ + { + name: "file-access", + error: "invalid file access type", + }, + { + name: "network", + error: "invalid network type", + }, + { + name: "qdisc", + error: "invalid qdisc", + }, + { + name: "watchdog-action", + error: "invalid watchdog action", + }, + { + name: "ref-leak-mode", + error: "invalid ref leak mode", + }, + } { + t.Run(tc.name, func(t *testing.T) { + defer setDefault(tc.name) + if err := flag.CommandLine.Lookup(tc.name).Value.Set("invalid"); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("flag.Value.Set(invalid) wrong error reported: %v", err) + } + }) + } +} + +func TestValidationFail(t *testing.T) { + for _, tc := range []struct { + name string + flags map[string]string + error string + }{ + { + name: "shared+overlay", + flags: map[string]string{ + "file-access": "shared", + "overlay": "true", + }, + error: "overlay flag is incompatible", + }, + { + name: "network-channels", + flags: map[string]string{ + "num-network-channels": "-1", + }, + error: "num_network_channels must be > 0", + }, + } { + t.Run(tc.name, func(t *testing.T) { + for name, val := range tc.flags { + defer setDefault(name) + if err := flag.CommandLine.Lookup(name).Value.Set(val); err != nil { + t.Errorf("%s=%q: %v", name, val, err) + } + } + if _, err := NewFromFlags(); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("NewFromFlags() wrong error reported: %v", err) + } + }) + } +} + +func TestOverride(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.AllowFlagOverride = true + + t.Run("string", func(t *testing.T) { + c.RootDir = "foobar" + if err := c.Override("root", "bar"); err != nil { + t.Fatalf("Override(root, bar) failed: %v", err) + } + defer setDefault("root") + if c.RootDir != "bar" { + t.Errorf("Override(root, bar) didn't work: %+v", c) + } + }) + + t.Run("bool", func(t *testing.T) { + c.Debug = true + if err := c.Override("debug", "false"); err != nil { + t.Fatalf("Override(debug, false) failed: %v", err) + } + defer setDefault("debug") + if c.Debug { + t.Errorf("Override(debug, false) didn't work: %+v", c) + } + }) + + t.Run("enum", func(t *testing.T) { + c.FileAccess = FileAccessShared + if err := c.Override("file-access", "exclusive"); err != nil { + t.Fatalf("Override(file-access, exclusive) failed: %v", err) + } + defer setDefault("file-access") + if c.FileAccess != FileAccessExclusive { + t.Errorf("Override(file-access, exclusive) didn't work: %+v", c) + } + }) +} + +func TestOverrideDisabled(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + const errMsg = "flag override disabled" + if err := c.Override("root", "path"); err == nil || !strings.Contains(err.Error(), errMsg) { + t.Errorf("Override() wrong error: %v", err) + } +} + +func TestOverrideError(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.AllowFlagOverride = true + for _, tc := range []struct { + name string + value string + error string + }{ + { + name: "invalid", + value: "valid", + error: `flag "invalid" not found`, + }, + { + name: "debug", + value: "invalid", + error: "error setting flag debug", + }, + { + name: "file-access", + value: "invalid", + error: "invalid file access type", + }, + } { + t.Run(tc.name, func(t *testing.T) { + if err := c.Override(tc.name, tc.value); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("Override(%q, %q) wrong error: %v", tc.name, tc.value, err) + } + }) + } +} diff --git a/runsc/config/flags.go b/runsc/config/flags.go new file mode 100644 index 000000000..a5f25cfa2 --- /dev/null +++ b/runsc/config/flags.go @@ -0,0 +1,205 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strconv" + + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/runsc/flag" +) + +var registration sync.Once + +// This is the set of flags used to populate Config. +func RegisterFlags() { + registration.Do(func() { + // Although these flags are not part of the OCI spec, they are used by + // Docker, and thus should not be changed. + flag.String("root", "", "root directory for storage of container state.") + flag.String("log", "", "file path where internal debug information is written, default is stdout.") + flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") + flag.Bool("debug", false, "enable debug logging.") + + // These flags are unique to runsc, and are used to configure parts of the + // system that are not covered by the runtime spec. + + // Debugging flags. + flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") + flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") + flag.Bool("log-packets", false, "enable network packet logging.") + flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") + flag.Bool("alsologtostderr", false, "send log messages to stderr.") + flag.Bool("allow-flag-override", false, "allow OCI annotations (dev.gvisor.flag.<name>) to override flags for debugging.") + + // Debugging flags: strace related + flag.Bool("strace", false, "enable strace.") + flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") + flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") + + // Flags that control sandbox runtime behavior. + flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") + flag.Var(watchdogActionPtr(watchdog.LogWarning), "watchdog-action", "sets what action the watchdog takes when triggered: log (default), panic.") + flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") + flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") + flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") + flag.Var(leakModePtr(refs.NoLeakChecking), "ref-leak-mode", "sets reference leak check mode: disabled (default), log-names, log-traces.") + flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") + flag.Bool("oci-seccomp", false, "Enables loading OCI seccomp filters inside the sandbox.") + + // Flags that control sandbox runtime behavior: FS related. + flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") + flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") + flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") + flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") + flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") + flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") + + // Flags that control sandbox runtime behavior: network related. + flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") + flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") + flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") + flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.") + flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.") + flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.") + flag.Var(queueingDisciplinePtr(QDiscFIFO), "qdisc", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") + flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") + + // Test flags, not to be used outside tests, ever. + flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") + flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + }) +} + +// NewFromFlags creates a new Config with values coming from command line flags. +func NewFromFlags() (*Config, error) { + conf := &Config{} + + obj := reflect.ValueOf(conf).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + name, ok := f.Tag.Lookup("flag") + if !ok { + // No flag set for this field. + continue + } + fl := flag.CommandLine.Lookup(name) + if fl == nil { + panic(fmt.Sprintf("Flag %q not found", name)) + } + x := reflect.ValueOf(flag.Get(fl.Value)) + obj.Field(i).Set(x) + } + + if len(conf.RootDir) == 0 { + // If not set, set default root dir to something (hopefully) user-writeable. + conf.RootDir = "/var/run/runsc" + if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { + conf.RootDir = filepath.Join(runtimeDir, "runsc") + } + } + + if err := conf.validate(); err != nil { + return nil, err + } + return conf, nil +} + +// ToFlags returns a slice of flags that correspond to the given Config. +func (c *Config) ToFlags() []string { + var rv []string + + obj := reflect.ValueOf(c).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + name, ok := f.Tag.Lookup("flag") + if !ok { + // No flag set for this field. + continue + } + val := getVal(obj.Field(i)) + + flag := flag.CommandLine.Lookup(name) + if flag == nil { + panic(fmt.Sprintf("Flag %q not found", name)) + } + if val == flag.DefValue { + continue + } + rv = append(rv, fmt.Sprintf("--%s=%s", flag.Name, val)) + } + return rv +} + +// Override writes a new value to a flag. +func (c *Config) Override(name string, value string) error { + if !c.AllowFlagOverride { + return fmt.Errorf("flag override disabled, use --allow-flag-override to enable it") + } + + obj := reflect.ValueOf(c).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + fieldName, ok := f.Tag.Lookup("flag") + if !ok || fieldName != name { + // Not a flag field, or flag name doesn't match. + continue + } + fl := flag.CommandLine.Lookup(name) + if fl == nil { + // Flag must exist if there is a field match above. + panic(fmt.Sprintf("Flag %q not found", name)) + } + + // Use flag to convert the string value to the underlying flag type, using + // the same rules as the command-line for consistency. + if err := fl.Value.Set(value); err != nil { + return fmt.Errorf("error setting flag %s=%q: %w", name, value, err) + } + x := reflect.ValueOf(flag.Get(fl.Value)) + obj.Field(i).Set(x) + + // Validates the config again to ensure it's left in a consistent state. + return c.validate() + } + return fmt.Errorf("flag %q not found. Cannot set it to %q", name, value) +} + +func getVal(field reflect.Value) string { + if str, ok := field.Addr().Interface().(fmt.Stringer); ok { + return str.String() + } + switch field.Kind() { + case reflect.Bool: + return strconv.FormatBool(field.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(field.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return strconv.FormatUint(field.Uint(), 10) + case reflect.String: + return field.String() + default: + panic("unknown type " + field.Kind().String()) + } +} diff --git a/runsc/console/console.go b/runsc/console/console.go index 64b23639a..dbb88e117 100644 --- a/runsc/console/console.go +++ b/runsc/console/console.go @@ -24,11 +24,11 @@ import ( "golang.org/x/sys/unix" ) -// NewWithSocket creates pty master/slave pair, sends the master FD over the given -// socket, and returns the slave. +// NewWithSocket creates pty master/replica pair, sends the master FD over the given +// socket, and returns the replica. func NewWithSocket(socketPath string) (*os.File, error) { - // Create a new pty master and slave. - ptyMaster, ptySlave, err := pty.Open() + // Create a new pty master and replica. + ptyMaster, ptyReplica, err := pty.Open() if err != nil { return nil, fmt.Errorf("opening pty: %v", err) } @@ -37,18 +37,18 @@ func NewWithSocket(socketPath string) (*os.File, error) { // Get a connection to the socket path. conn, err := net.Dial("unix", socketPath) if err != nil { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("dialing socket %q: %v", socketPath, err) } defer conn.Close() uc, ok := conn.(*net.UnixConn) if !ok { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("connection is not a UnixConn: %T", conn) } socket, err := uc.File() if err != nil { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("getting file for unix socket %v: %v", uc, err) } defer socket.Close() @@ -56,8 +56,8 @@ func NewWithSocket(socketPath string) (*os.File, error) { // Send the master FD over the connection. msg := unix.UnixRights(int(ptyMaster.Fd())) if err := unix.Sendmsg(int(socket.Fd()), []byte("pty-master"), msg, nil, 0); err != nil { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("sending console over unix socket %q: %v", socketPath, err) } - return ptySlave, nil + return ptyReplica, nil } diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 49cfb0837..c33755482 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -23,11 +23,12 @@ go_library( "//pkg/sync", "//runsc/boot", "//runsc/cgroup", + "//runsc/config", "//runsc/sandbox", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_gofrs_flock//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) @@ -65,10 +66,11 @@ go_test( "//pkg/urpc", "//runsc/boot", "//runsc/boot/platforms", + "//runsc/config", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_kr_pty//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 3813c6b93..4228399b8 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -122,6 +122,7 @@ func TestConsoleSocket(t *testing.T) { for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { spec := testutil.NewSpecWithArgs("true") + spec.Process.Terminal = true _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -184,14 +185,14 @@ func TestJobControlSignalExec(t *testing.T) { t.Fatalf("error starting container: %v", err) } - // Create a pty master/slave. The slave will be passed to the exec + // Create a pty master/replica. The replica will be passed to the exec // process. - ptyMaster, ptySlave, err := pty.Open() + ptyMaster, ptyReplica, err := pty.Open() if err != nil { t.Fatalf("error opening pty: %v", err) } defer ptyMaster.Close() - defer ptySlave.Close() + defer ptyReplica.Close() // Exec bash and attach a terminal. Note that occasionally /bin/sh // may be a different shell or have a different configuration (such @@ -202,9 +203,9 @@ func TestJobControlSignalExec(t *testing.T) { // Don't let bash execute from profile or rc files, otherwise // our PID counts get messed up. Argv: []string{"/bin/bash", "--noprofile", "--norc"}, - // Pass the pty slave as FD 0, 1, and 2. + // Pass the pty replica as FD 0, 1, and 2. FilePayload: urpc.FilePayload{ - Files: []*os.File{ptySlave, ptySlave, ptySlave}, + Files: []*os.File{ptyReplica, ptyReplica, ptyReplica}, }, StdioIsPty: true, } diff --git a/runsc/container/container.go b/runsc/container/container.go index 6d297d0df..63478ba8c 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/sighandling" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/cgroup" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/sandbox" "gvisor.dev/gvisor/runsc/specutils" ) @@ -269,7 +270,7 @@ type Args struct { // New creates the container in a new Sandbox process, unless the metadata // indicates that an existing Sandbox should be used. The caller must call // Destroy() on the container. -func New(conf *boot.Config, args Args) (*Container, error) { +func New(conf *config.Config, args Args) (*Container, error) { log.Debugf("Create container %q in root dir: %s", args.ID, conf.RootDir) if err := validateID(args.ID); err != nil { return nil, err @@ -324,7 +325,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { } } if err := runInCgroup(cg, func() error { - ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir) + ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir, args.Attached) if err != nil { return err } @@ -397,7 +398,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { } // Start starts running the containerized process inside the sandbox. -func (c *Container) Start(conf *boot.Config) error { +func (c *Container) Start(conf *config.Config) error { log.Debugf("Start container %q", c.ID) if err := c.Saver.lock(); err != nil { @@ -427,7 +428,7 @@ func (c *Container) Start(conf *boot.Config) error { // the start (and all their children processes). if err := runInCgroup(c.Sandbox.Cgroup, func() error { // Create the gofer process. - ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir) + ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir, false) if err != nil { return err } @@ -472,7 +473,7 @@ func (c *Container) Start(conf *boot.Config) error { // Restore takes a container and replaces its kernel and file system // to restore a container from its state file. -func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error { +func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile string) error { log.Debugf("Restore container %q", c.ID) if err := c.Saver.lock(); err != nil { return err @@ -499,7 +500,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str } // Run is a helper that calls Create + Start + Wait. -func Run(conf *boot.Config, args Args) (syscall.WaitStatus, error) { +func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { log.Debugf("Run container %q in root dir: %s", args.ID, conf.RootDir) c, err := New(conf, args) if err != nil { @@ -861,7 +862,7 @@ func (c *Container) waitForStopped() error { return backoff.Retry(op, b) } -func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string) ([]*os.File, *os.File, error) { +func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bundleDir string, attached bool) ([]*os.File, *os.File, error) { // Start with the general config flags. args := conf.ToFlags() @@ -901,9 +902,6 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund } args = append(args, "gofer", "--bundle", bundleDir) - if conf.Overlay { - args = append(args, "--panic-on-write=true") - } // Open the spec file to donate to the sandbox. specFile, err := specutils.OpenSpec(bundleDir) @@ -955,6 +953,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund cmd.ExtraFiles = goferEnds cmd.Args[0] = "runsc-gofer" + if attached { + // The gofer is attached to the lifetime of this process, so it + // should synchronously die when this process dies. + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGKILL, + } + } + // Enter new namespaces to isolate from the rest of the system. Don't unshare // cgroup because gofer is added to a cgroup in the caller's namespace. nss := []specs.LinuxNamespace{ diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index cd76645bd..1f8e277cc 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -41,8 +41,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot/platforms" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -250,7 +251,7 @@ func readOutputNum(file string, position int) (int, error) { // run starts the sandbox and waits for it to exit, checking that the // application succeeded. -func run(spec *specs.Spec, conf *boot.Config) error { +func run(spec *specs.Spec, conf *config.Config) error { _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { return fmt.Errorf("error setting up container: %v", err) @@ -289,26 +290,24 @@ var ( ) // configs generates different configurations to run tests. -func configs(t *testing.T, opts ...configOption) map[string]*boot.Config { +func configs(t *testing.T, opts ...configOption) map[string]*config.Config { // Always load the default config. - cs := make(map[string]*boot.Config) + cs := make(map[string]*config.Config) + testutil.TestConfig(t) for _, o := range opts { + c := testutil.TestConfig(t) switch o { case overlay: - c := testutil.TestConfig(t) c.Overlay = true cs["overlay"] = c case ptrace: - c := testutil.TestConfig(t) c.Platform = platforms.Ptrace cs["ptrace"] = c case kvm: - c := testutil.TestConfig(t) c.Platform = platforms.KVM cs["kvm"] = c case nonExclusiveFS: - c := testutil.TestConfig(t) - c.FileAccess = boot.FileAccessShared + c.FileAccess = config.FileAccessShared cs["non-exclusive"] = c default: panic(fmt.Sprintf("unknown config option %v", o)) @@ -317,23 +316,14 @@ func configs(t *testing.T, opts ...configOption) map[string]*boot.Config { return cs } -func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*boot.Config { - vfs1 := configs(t, opts...) - - var optsVFS2 []configOption - for _, opt := range opts { - // TODO(gvisor.dev/issue/1487): Enable overlay tests. - if opt != overlay { - optsVFS2 = append(optsVFS2, opt) - } - } - - for key, value := range configs(t, optsVFS2...) { +// TODO(gvisor.dev/issue/1624): Merge with configs when VFS2 is the default. +func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*config.Config { + all := configs(t, opts...) + for key, value := range configs(t, opts...) { value.VFS2 = true - vfs1[key+"VFS2"] = value + all[key+"VFS2"] = value } - - return vfs1 + return all } // TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle. @@ -512,7 +502,7 @@ func TestExePath(t *testing.T) { t.Fatalf("error making directory: %v", err) } - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { for _, test := range []struct { path string @@ -643,7 +633,9 @@ func TestExec(t *testing.T) { if err != nil { t.Fatalf("error creating temporary directory: %v", err) } - cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100", dir) + // Note that some shells may exec the final command in a sequence as + // an optimization. We avoid this here by adding the exit 0. + cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100 && exit 0", dir) spec := testutil.NewSpecWithArgs("sh", "-c", cmd) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) @@ -835,7 +827,7 @@ func TestExecProcList(t *testing.T) { // TestKillPid verifies that we can signal individual exec'd processes. func TestKillPid(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { app, err := testutil.FindFile("test/cmd/test_app/test_app") if err != nil { @@ -903,13 +895,15 @@ func TestKillPid(t *testing.T) { } } -// TestCheckpointRestore creates a container that continuously writes successive integers -// to a file. To test checkpoint and restore functionality, the container is -// checkpointed and the last number printed to the file is recorded. Then, it is restored in two -// new containers and the first number printed from these containers is checked. Both should -// be the next consecutive number after the last number from the checkpointed container. +// TestCheckpointRestore creates a container that continuously writes successive +// integers to a file. To test checkpoint and restore functionality, the +// container is checkpointed and the last number printed to the file is +// recorded. Then, it is restored in two new containers and the first number +// printed from these containers is checked. Both should be the next consecutive +// number after the last number from the checkpointed container. func TestCheckpointRestore(t *testing.T) { // Skip overlay because test requires writing to host file. + // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test") @@ -1071,6 +1065,7 @@ func TestCheckpointRestore(t *testing.T) { // with filesystem Unix Domain Socket use. func TestUnixDomainSockets(t *testing.T) { // Skip overlay because test requires writing to host file. + // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { // UDS path is limited to 108 chars for compatibility with older systems. @@ -1208,7 +1203,7 @@ func TestUnixDomainSockets(t *testing.T) { // recreated. Then it resumes the container, verify that the file gets created // again. func TestPauseResume(t *testing.T) { - for name, conf := range configs(t, noOverlay...) { + for name, conf := range configsWithVFS2(t, noOverlay...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock") if err != nil { @@ -1468,7 +1463,7 @@ func TestRunNonRoot(t *testing.T) { // TestMountNewDir checks that runsc will create destination directory if it // doesn't exit. func TestMountNewDir(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { root, err := ioutil.TempDir(testutil.TmpDir(), "root") if err != nil { @@ -1488,6 +1483,8 @@ func TestMountNewDir(t *testing.T) { Source: srcDir, Type: "bind", }) + // Extra points for creating the mount with a readonly root. + spec.Root.Readonly = true if err := run(spec, conf); err != nil { t.Fatalf("error running sandbox: %v", err) @@ -1497,17 +1494,17 @@ func TestMountNewDir(t *testing.T) { } func TestReadonlyRoot(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { - spec := testutil.NewSpecWithArgs("/bin/touch", "/foo") + spec := testutil.NewSpecWithArgs("sleep", "100") spec.Root.Readonly = true + _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) } defer cleanup() - // Create, start and wait for the container. args := Args{ ID: testutil.RandomContainerID(), Spec: spec, @@ -1522,12 +1519,82 @@ func TestReadonlyRoot(t *testing.T) { t.Fatalf("error starting container: %v", err) } - ws, err := c.Wait() + // Read mounts to check that root is readonly. + out, ws, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / '") + if err != nil || ws != 0 { + t.Fatalf("exec failed, ws: %v, err: %v", ws, err) + } + t.Logf("root mount: %q", out) + if !strings.Contains(string(out), "(ro)") { + t.Errorf("root not mounted readonly: %q", out) + } + + // Check that file cannot be created. + ws, err = execute(c, "/bin/touch", "/foo") if err != nil { - t.Fatalf("error waiting on container: %v", err) + t.Fatalf("touch file in ro mount: %v", err) } if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { - t.Fatalf("container failed, waitStatus: %v", ws) + t.Fatalf("wrong waitStatus: %v", ws) + } + }) + } +} + +func TestReadonlyMount(t *testing.T) { + for name, conf := range configsWithVFS2(t, all...) { + t.Run(name, func(t *testing.T) { + dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + spec := testutil.NewSpecWithArgs("sleep", "100") + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: dir, + Source: dir, + Type: "bind", + Options: []string{"ro"}, + }) + spec.Root.Readonly = false + + _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) + if err != nil { + t.Fatalf("error setting up container: %v", err) + } + defer cleanup() + + args := Args{ + ID: testutil.RandomContainerID(), + Spec: spec, + BundleDir: bundleDir, + } + c, err := New(conf, args) + if err != nil { + t.Fatalf("error creating container: %v", err) + } + defer c.Destroy() + if err := c.Start(conf); err != nil { + t.Fatalf("error starting container: %v", err) + } + + // Read mounts to check that volume is readonly. + cmd := fmt.Sprintf("mount | grep ' %s '", dir) + out, ws, err := executeCombinedOutput(c, "/bin/sh", "-c", cmd) + if err != nil || ws != 0 { + t.Fatalf("exec failed, ws: %v, err: %v", ws, err) + } + t.Logf("mount: %q", out) + if !strings.Contains(string(out), "(ro)") { + t.Errorf("volume not mounted readonly: %q", out) + } + + // Check that file cannot be created. + ws, err = execute(c, "/bin/touch", path.Join(dir, "file")) + if err != nil { + t.Fatalf("touch file in ro mount: %v", err) + } + if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { + t.Fatalf("wrong WaitStatus: %v", ws) } }) } @@ -1614,54 +1681,6 @@ func TestUIDMap(t *testing.T) { } } -func TestReadonlyMount(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { - t.Run(name, func(t *testing.T) { - dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") - spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file")) - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: dir, - Source: dir, - Type: "bind", - Options: []string{"ro"}, - }) - spec.Root.Readonly = false - - _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer cleanup() - - // Create, start and wait for the container. - args := Args{ - ID: testutil.RandomContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { - t.Fatalf("container failed, waitStatus: %v", ws) - } - }) - } -} - // TestAbbreviatedIDs checks that runsc supports using abbreviated container // IDs in place of full IDs. func TestAbbreviatedIDs(t *testing.T) { @@ -1828,8 +1847,9 @@ func TestUserLog(t *testing.T) { t.Fatal("error finding test_app:", err) } - // sched_rr_get_interval = 148 - not implemented in gvisor. - spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148") + // sched_rr_get_interval - not implemented in gvisor. + num := strconv.Itoa(syscall.SYS_SCHED_RR_GET_INTERVAL) + spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall="+num) conf := testutil.TestConfig(t) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { @@ -2011,7 +2031,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) { } func TestCreateWorkingDir(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create") if err != nil { @@ -2114,27 +2134,19 @@ func TestMountPropagation(t *testing.T) { // Check that mount didn't propagate to private mount. privFile := filepath.Join(priv, "mnt", "file") - execArgs := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "!", "-f", privFile}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { + if ws, err := execute(cont, "/usr/bin/test", "!", "-f", privFile); err != nil || ws != 0 { t.Fatalf("exec: test ! -f %q, ws: %v, err: %v", privFile, ws, err) } // Check that mount propagated to slave mount. slaveFile := filepath.Join(slave, "mnt", "file") - execArgs = &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", slaveFile}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { + if ws, err := execute(cont, "/usr/bin/test", "-f", slaveFile); err != nil || ws != 0 { t.Fatalf("exec: test -f %q, ws: %v, err: %v", privFile, ws, err) } } func TestMountSymlink(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink") if err != nil { @@ -2194,11 +2206,7 @@ func TestMountSymlink(t *testing.T) { // Check that symlink was resolved and mount was created where the symlink // is pointing to. file := path.Join(target, "file") - execArgs := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", file}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { + if ws, err := execute(cont, "/usr/bin/test", "-f", file); err != nil || ws != 0 { t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err) } }) @@ -2324,6 +2332,35 @@ func TestTTYField(t *testing.T) { } } +func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) { + args := &control.ExecArgs{ + Filename: name, + Argv: append([]string{name}, arg...), + } + return cont.executeSync(args) +} + +func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte, syscall.WaitStatus, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, 0, err + } + defer r.Close() + + args := &control.ExecArgs{ + Filename: name, + Argv: append([]string{name}, arg...), + FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, w, w}}, + } + ws, err := cont.executeSync(args) + w.Close() + if err != nil { + return nil, 0, err + } + out, err := ioutil.ReadAll(r) + return out, ws, err +} + // executeSync synchronously executes a new process. func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { pid, err := cont.Execute(args) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index a27a01942..850e80290 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -60,7 +61,7 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { return specs, ids } -func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { +func startContainers(conf *config.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { if len(conf.RootDir) == 0 { panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } @@ -100,19 +101,20 @@ type execDesc struct { c *Container cmd []string want int - desc string + name string } -func execMany(execs []execDesc) error { +func execMany(t *testing.T, execs []execDesc) { for _, exec := range execs { - args := &control.ExecArgs{Argv: exec.cmd} - if ws, err := exec.c.executeSync(args); err != nil { - return fmt.Errorf("error executing %+v: %v", args, err) - } else if ws.ExitStatus() != exec.want { - return fmt.Errorf("%q: exec %q got exit status: %d, want: %d", exec.desc, exec.cmd, ws.ExitStatus(), exec.want) - } + t.Run(exec.name, func(t *testing.T) { + args := &control.ExecArgs{Argv: exec.cmd} + if ws, err := exec.c.executeSync(args); err != nil { + t.Errorf("error executing %+v: %v", args, err) + } else if ws.ExitStatus() != exec.want { + t.Errorf("%q: exec %q got exit status: %d, want: %d", exec.name, exec.cmd, ws.ExitStatus(), exec.want) + } + }) } - return nil } func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { @@ -167,7 +169,7 @@ func TestMultiContainerSanity(t *testing.T) { // TestMultiPIDNS checks that it is possible to run 2 dead-simple // containers in the same sandbox with different pidns. func TestMultiPIDNS(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -212,7 +214,7 @@ func TestMultiPIDNS(t *testing.T) { // TestMultiPIDNSPath checks the pidns path. func TestMultiPIDNSPath(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -478,7 +480,7 @@ func TestMultiContainerMount(t *testing.T) { // TestMultiContainerSignal checks that it is possible to signal individual // containers without killing the entire sandbox. func TestMultiContainerSignal(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -578,7 +580,7 @@ func TestMultiContainerDestroy(t *testing.T) { t.Fatal("error finding test_app:", err) } - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1072,7 +1074,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { // Test that pod shared mounts are properly mounted in 2 containers and that // changes from one container is reflected in the other. func TestMultiContainerSharedMount(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1110,84 +1112,82 @@ func TestMultiContainerSharedMount(t *testing.T) { { c: containers[0], cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", + name: "directory is mounted in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", + name: "directory is mounted in container1", }, { c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - desc: "create file in container0", + cmd: []string{"/bin/touch", file0}, + name: "create file in container0", }, { c: containers[0], cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file appears in container0", + name: "file appears in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file appears in container1", + name: "file appears in container1", }, { c: containers[1], cmd: []string{"/bin/rm", file1}, - desc: "file removed from container1", + name: "remove file from container1", }, { c: containers[0], cmd: []string{"/usr/bin/test", "!", "-f", file0}, - desc: "file removed from container0", + name: "file removed from container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "!", "-f", file1}, - desc: "file removed from container1", + name: "file removed from container1", }, { c: containers[1], cmd: []string{"/bin/mkdir", file1}, - desc: "create directory in container1", + name: "create directory in container1", }, { c: containers[0], cmd: []string{"/usr/bin/test", "-d", file0}, - desc: "dir appears in container0", + name: "dir appears in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-d", file1}, - desc: "dir appears in container1", + name: "dir appears in container1", }, { c: containers[0], cmd: []string{"/bin/rmdir", file0}, - desc: "create directory in container0", + name: "remove directory from container0", }, { c: containers[0], cmd: []string{"/usr/bin/test", "!", "-d", file0}, - desc: "dir removed from container0", + name: "dir removed from container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "!", "-d", file1}, - desc: "dir removed from container1", + name: "dir removed from container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) }) } } // Test that pod mounts are mounted as readonly when requested. func TestMultiContainerSharedMountReadonly(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1225,36 +1225,34 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { { c: containers[0], cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", + name: "directory is mounted in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", + name: "directory is mounted in container1", }, { c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, + cmd: []string{"/bin/touch", file0}, want: 1, - desc: "fails to write to container0", + name: "fails to write to container0", }, { c: containers[1], - cmd: []string{"/usr/bin/touch", file1}, + cmd: []string{"/bin/touch", file1}, want: 1, - desc: "fails to write to container1", + name: "fails to write to container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) }) } } // Test that shared pod mounts continue to work after container is restarted. func TestMultiContainerSharedMountRestart(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1291,23 +1289,21 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { execs := []execDesc{ { c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - desc: "create file in container0", + cmd: []string{"/bin/touch", file0}, + name: "create file in container0", }, { c: containers[0], cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file appears in container0", + name: "file appears in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file appears in container1", + name: "file appears in container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) containers[1].Destroy() @@ -1334,86 +1330,84 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { { c: containers[0], cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file is still in container0", + name: "file is still in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file is still in container1", + name: "file is still in container1", }, { c: containers[1], cmd: []string{"/bin/rm", file1}, - desc: "file removed from container1", + name: "file removed from container1", }, { c: containers[0], cmd: []string{"/usr/bin/test", "!", "-f", file0}, - desc: "file removed from container0", + name: "file removed from container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "!", "-f", file1}, - desc: "file removed from container1", + name: "file removed from container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) }) } } // Test that unsupported pod mounts options are ignored when matching master and -// slave mounts. +// replica mounts. func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { - rootDir, cleanup, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer cleanup() - - conf := testutil.TestConfig(t) - conf.RootDir = rootDir + for name, conf := range configsWithVFS2(t, all...) { + t.Run(name, func(t *testing.T) { + rootDir, cleanup, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer cleanup() + conf.RootDir = rootDir - // Setup the containers. - sleep := []string{"/bin/sleep", "100"} - podSpec, ids := createSpecs(sleep, sleep) - mnt0 := specs.Mount{ - Destination: "/mydir/test", - Source: "/some/dir", - Type: "tmpfs", - Options: []string{"rw", "rbind", "relatime"}, - } - podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) + // Setup the containers. + sleep := []string{"/bin/sleep", "100"} + podSpec, ids := createSpecs(sleep, sleep) + mnt0 := specs.Mount{ + Destination: "/mydir/test", + Source: "/some/dir", + Type: "tmpfs", + Options: []string{"rw", "rbind", "relatime"}, + } + podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) - mnt1 := mnt0 - mnt1.Destination = "/mydir2/test2" - mnt1.Options = []string{"rw", "nosuid"} - podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) + mnt1 := mnt0 + mnt1.Destination = "/mydir2/test2" + mnt1.Options = []string{"rw", "nosuid"} + podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) - createSharedMount(mnt0, "test-mount", podSpec...) + createSharedMount(mnt0, "test-mount", podSpec...) - containers, cleanup, err := startContainers(conf, podSpec, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() - execs := []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) + execs := []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, + name: "directory is mounted in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, + name: "directory is mounted in container1", + }, + } + execMany(t, execs) + }) } } @@ -1523,8 +1517,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { } // Check that container isn't running anymore. - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err == nil { + if _, err := execute(c, "/bin/true"); err == nil { t.Fatalf("Container %q was not stopped after gofer death", c.ID) } @@ -1539,8 +1532,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { if err := waitForProcessList(c, pl); err != nil { t.Errorf("Container %q was affected by another container: %v", c.ID, err) } - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err != nil { + if _, err := execute(c, "/bin/true"); err != nil { t.Fatalf("Container %q was affected by another container: %v", c.ID, err) } } @@ -1562,8 +1554,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Check that entire sandbox isn't running anymore. for _, c := range containers { - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err == nil { + if _, err := execute(c, "/bin/true"); err == nil { t.Fatalf("Container %q was not stopped after gofer death", c.ID) } } @@ -1700,12 +1691,11 @@ func TestMultiContainerRunNonRoot(t *testing.T) { } // TestMultiContainerHomeEnvDir tests that the HOME environment variable is set -// for root containers, sub-containers, and execed processes. +// for root containers, sub-containers, and exec'ed processes. func TestMultiContainerHomeEnvDir(t *testing.T) { - // TODO(gvisor.dev/issue/1487): VFSv2 configs failing. // NOTE: Don't use overlay since we need changes to persist to the temp dir // outside the sandbox. - for testName, conf := range configs(t, noOverlay...) { + for testName, conf := range configsWithVFS2(t, noOverlay...) { t.Run(testName, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() @@ -1725,12 +1715,11 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { homeDirs[name] = homeFile } - // We will sleep in the root container in order to ensure that - // the root container doesn't terminate before sub containers can be - // created. - rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())} - subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())} - execCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name())} + // We will sleep in the root container in order to ensure that the root + //container doesn't terminate before sub containers can be created. + rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf(`printf "$HOME" > %s; sleep 1000`, homeDirs["root"].Name())} + subCmd := []string{"/bin/sh", "-c", fmt.Sprintf(`printf "$HOME" > %s`, homeDirs["sub"].Name())} + execCmd := fmt.Sprintf(`printf "$HOME" > %s`, homeDirs["exec"].Name()) // Setup the containers, a root container and sub container. specConfig, ids := createSpecs(rootCmd, subCmd) @@ -1741,9 +1730,8 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { defer cleanup() // Exec into the root container synchronously. - args := &control.ExecArgs{Argv: execCmd} - if _, err := containers[0].executeSync(args); err != nil { - t.Errorf("error executing %+v: %v", args, err) + if _, err := execute(containers[0], "/bin/sh", "-c", execCmd); err != nil { + t.Errorf("error executing %+v: %v", execCmd, err) } // Wait for the subcontainer to finish. diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go index bac177a88..cb5bffb89 100644 --- a/runsc/container/shared_volume_test.go +++ b/runsc/container/shared_volume_test.go @@ -25,14 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" ) // TestSharedVolume checks that modifications to a volume mount are propagated // into and out of the sandbox. func TestSharedVolume(t *testing.T) { conf := testutil.TestConfig(t) - conf.FileAccess = boot.FileAccessShared + conf.FileAccess = config.FileAccessShared // Main process just sleeps. We will use "exec" to probe the state of // the filesystem. @@ -168,11 +168,7 @@ func TestSharedVolume(t *testing.T) { func checkFile(c *Container, filename string, want []byte) error { cpy := filename + ".copy" - argsCp := &control.ExecArgs{ - Filename: "/bin/cp", - Argv: []string{"cp", "-f", filename, cpy}, - } - if _, err := c.executeSync(argsCp); err != nil { + if _, err := execute(c, "/bin/cp", "-f", filename, cpy); err != nil { return fmt.Errorf("unexpected error copying file %q to %q: %v", filename, cpy, err) } got, err := ioutil.ReadFile(cpy) @@ -189,7 +185,7 @@ func checkFile(c *Container, filename string, want []byte) error { // is reflected inside. func TestSharedVolumeFile(t *testing.T) { conf := testutil.TestConfig(t) - conf.FileAccess = boot.FileAccessShared + conf.FileAccess = config.FileAccessShared // Main process just sleeps. We will use "exec" to probe the state of // the filesystem. @@ -235,11 +231,7 @@ func TestSharedVolumeFile(t *testing.T) { } // Append to file inside the container and check that content is not lost. - argsAppend := &control.ExecArgs{ - Filename: "/bin/bash", - Argv: []string{"bash", "-c", "echo -n sandbox- >> " + filename}, - } - if _, err := c.executeSync(argsAppend); err != nil { + if _, err := execute(c, "/bin/bash", "-c", "echo -n sandbox- >> "+filename); err != nil { t.Fatalf("unexpected error appending file %q: %v", filename, err) } want = []byte("host-sandbox-") diff --git a/runsc/flag/flag.go b/runsc/flag/flag.go index 0ca4829d7..ba1ff833f 100644 --- a/runsc/flag/flag.go +++ b/runsc/flag/flag.go @@ -21,13 +21,19 @@ import ( type FlagSet = flag.FlagSet var ( - NewFlagSet = flag.NewFlagSet - String = flag.String Bool = flag.Bool - Int = flag.Int - Uint = flag.Uint CommandLine = flag.CommandLine + Int = flag.Int + NewFlagSet = flag.NewFlagSet Parse = flag.Parse + String = flag.String + Uint = flag.Uint + Var = flag.Var ) const ContinueOnError = flag.ContinueOnError + +// Get returns the flag's underlying object. +func Get(v flag.Value) interface{} { + return v.(flag.Getter).Get() +} diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 1036b0630..96c57a426 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -31,5 +31,7 @@ go_test( deps = [ "//pkg/log", "//pkg/p9", + "//pkg/test/testutil", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 1dce36965..39b8a0b1e 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -27,62 +27,51 @@ import ( var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_ACCEPT: {}, syscall.SYS_CLOCK_GETTIME: {}, - syscall.SYS_CLONE: []seccomp.Rule{ - { - seccomp.AllowValue( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), - }, - }, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, - syscall.SYS_EPOLL_CTL: {}, + syscall.SYS_CLOSE: {}, + syscall.SYS_DUP: {}, + syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, }, syscall.SYS_EVENTFD2: []seccomp.Rule{ { - seccomp.AllowValue(0), - seccomp.AllowValue(0), + seccomp.EqualTo(0), + seccomp.EqualTo(0), }, }, syscall.SYS_EXIT: {}, syscall.SYS_EXIT_GROUP: {}, syscall.SYS_FALLOCATE: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, }, syscall.SYS_FCHMOD: {}, syscall.SYS_FCHOWNAT: {}, syscall.SYS_FCNTL: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_SETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_SETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFD), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFD), }, // Used by flipcall.PacketWindowAllocator.Init(). { - seccomp.AllowAny{}, - seccomp.AllowValue(unix.F_ADD_SEALS), + seccomp.MatchAny{}, + seccomp.EqualTo(unix.F_ADD_SEALS), }, }, syscall.SYS_FSTAT: {}, @@ -91,31 +80,31 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_FTRUNCATE: {}, syscall.SYS_FUTEX: { seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, // Non-private futex used for flipcall. seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, }, syscall.SYS_GETDENTS64: {}, @@ -128,6 +117,7 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_MADVISE: {}, unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init(). syscall.SYS_MKDIRAT: {}, + syscall.SYS_MKNODAT: {}, // Used by the Go runtime as a temporarily workaround for a Linux // 5.2-5.4 bug. // @@ -136,28 +126,28 @@ var allowedSyscalls = seccomp.SyscallRules{ // TODO(b/148688965): Remove once this is gone from Go. syscall.SYS_MLOCK: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(4096), + seccomp.MatchAny{}, + seccomp.EqualTo(4096), }, }, syscall.SYS_MMAP: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_SHARED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_SHARED), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), }, }, syscall.SYS_MPROTECT: {}, @@ -171,14 +161,14 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_READLINKAT: {}, syscall.SYS_RECVMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), }, }, syscall.SYS_RENAMEAT: {}, @@ -189,33 +179,33 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_SENDMSG: []seccomp.Rule{ // Used by fdchannel.Endpoint.SendFD(). { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, // Used by unet.SocketWriter.WriteVec(). { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, // Used by fdchannel.NewConnectedSockets(). syscall.SYS_SOCKETPAIR: { { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, }, syscall.SYS_SYMLINKAT: {}, syscall.SYS_TGKILL: []seccomp.Rule{ { - seccomp.AllowValue(uint64(os.Getpid())), + seccomp.EqualTo(uint64(os.Getpid())), }, }, syscall.SYS_UNLINKAT: {}, @@ -226,24 +216,24 @@ var allowedSyscalls = seccomp.SyscallRules{ var udsSyscalls = seccomp.SyscallRules{ syscall.SYS_SOCKET: []seccomp.Rule{ { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_STREAM), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_STREAM), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_DGRAM), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_DGRAM), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_SEQPACKET), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_SEQPACKET), + seccomp.EqualTo(0), }, }, syscall.SYS_CONNECT: []seccomp.Rule{ { - seccomp.AllowAny{}, + seccomp.MatchAny{}, }, }, } diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go index a4b28cb8b..686753d96 100644 --- a/runsc/fsgofer/filter/config_amd64.go +++ b/runsc/fsgofer/filter/config_amd64.go @@ -25,8 +25,41 @@ import ( func init() { allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_GET_FS)}, - {seccomp.AllowValue(linux.ARCH_SET_FS)}, + // TODO(b/168828518): No longer used in Go 1.16+. + {seccomp.EqualTo(linux.ARCH_SET_FS)}, + } + + allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + // parent_tidptr and child_tidptr are always 0 because neither + // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used. + { + seccomp.EqualTo( + syscall.CLONE_VM | + syscall.CLONE_FS | + syscall.CLONE_FILES | + syscall.CLONE_SETTLS | + syscall.CLONE_SIGHAND | + syscall.CLONE_SYSVSEM | + syscall.CLONE_THREAD), + seccomp.MatchAny{}, // newsp + seccomp.EqualTo(0), // parent_tidptr + seccomp.EqualTo(0), // child_tidptr + seccomp.MatchAny{}, // tls + }, + { + // TODO(b/168828518): No longer used in Go 1.16+ (on amd64). + seccomp.EqualTo( + syscall.CLONE_VM | + syscall.CLONE_FS | + syscall.CLONE_FILES | + syscall.CLONE_SIGHAND | + syscall.CLONE_SYSVSEM | + syscall.CLONE_THREAD), + seccomp.MatchAny{}, // newsp + seccomp.EqualTo(0), // parent_tidptr + seccomp.EqualTo(0), // child_tidptr + seccomp.MatchAny{}, // tls + }, } allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go index d2697deb7..ff0cf77a0 100644 --- a/runsc/fsgofer/filter/config_arm64.go +++ b/runsc/fsgofer/filter/config_arm64.go @@ -23,5 +23,26 @@ import ( ) func init() { + allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + // parent_tidptr and child_tidptr are always 0 because neither + // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used. + { + seccomp.EqualTo( + syscall.CLONE_VM | + syscall.CLONE_FS | + syscall.CLONE_FILES | + syscall.CLONE_SIGHAND | + syscall.CLONE_SYSVSEM | + syscall.CLONE_THREAD), + seccomp.MatchAny{}, // newsp + // These arguments are left uninitialized by the Go + // runtime, so they may be anything (and are unused by + // the host). + seccomp.MatchAny{}, // parent_tidptr + seccomp.MatchAny{}, // tls + seccomp.MatchAny{}, // child_tidptr + }, + } + allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{} } diff --git a/runsc/fsgofer/filter/extra_filters_race.go b/runsc/fsgofer/filter/extra_filters_race.go index 885c92f7a..20a0732be 100644 --- a/runsc/fsgofer/filter/extra_filters_race.go +++ b/runsc/fsgofer/filter/extra_filters_race.go @@ -35,6 +35,7 @@ func instrumentationFilters() seccomp.SyscallRules { syscall.SYS_MUNLOCK: {}, syscall.SYS_NANOSLEEP: {}, syscall.SYS_OPEN: {}, + syscall.SYS_OPENAT: {}, syscall.SYS_SET_ROBUST_LIST: {}, // Used within glibc's malloc. syscall.SYS_TIME: {}, diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 74977c313..0b628c8ce 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -29,7 +29,6 @@ import ( "path/filepath" "runtime" "strconv" - "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -45,39 +44,11 @@ const ( // modes to ensure an unopened/closed file fails all mode checks. invalidMode = p9.OpenFlags(math.MaxUint32) - openFlags = syscall.O_NOFOLLOW | syscall.O_CLOEXEC -) - -type fileType int + openFlags = unix.O_NOFOLLOW | unix.O_CLOEXEC -const ( - regular fileType = iota - directory - symlink - socket - unknown + allowedOpenFlags = unix.O_TRUNC ) -// String implements fmt.Stringer. -func (f fileType) String() string { - switch f { - case regular: - return "regular" - case directory: - return "directory" - case symlink: - return "symlink" - case socket: - return "socket" - } - return "unknown" -} - -// ControlSocketAddr generates an abstract unix socket name for the given id. -func ControlSocketAddr(id string) string { - return fmt.Sprintf("\x00runsc-gofer.%s", id) -} - // Config sets configuration options for each attach point. type Config struct { // ROMount is set to true if this is a readonly mount. @@ -132,19 +103,19 @@ func (a *attachPoint) Attach() (p9.File, error) { return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix) } - f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { + f, readable, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { return fd.Open(a.prefix, openFlags|mode, 0) }) if err != nil { return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err) } - stat, err := stat(f.FD()) + stat, err := fstat(f.FD()) if err != nil { return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err) } - lf, err := newLocalFile(a, f, a.prefix, stat) + lf, err := newLocalFile(a, f, a.prefix, readable, stat) if err != nil { return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err) } @@ -153,7 +124,7 @@ func (a *attachPoint) Attach() (p9.File, error) { } // makeQID returns a unique QID for the given stat buffer. -func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { +func (a *attachPoint) makeQID(stat unix.Stat_t) p9.QID { a.deviceMu.Lock() defer a.deviceMu.Unlock() @@ -184,9 +155,7 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // localFile implements p9.File wrapping a local file. The underlying file // is opened during Walk() and stored in 'file' to be used with other // operations. The file is opened as readonly, unless it's a symlink or there is -// no read access, which requires O_PATH. 'file' is dup'ed when Walk(nil) is -// called to clone the file. This reduces the number of walks that need to be -// done by the host file system when files are reused. +// no read access, which requires O_PATH. // // The file may be reopened if the requested mode in Open() is not a subset of // current mode. Consequently, 'file' could have a mode wider than requested and @@ -198,13 +167,30 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // performance with 'overlay2' storage driver. overlay2 eagerly copies the // entire file up when it's opened in write mode, and would perform badly when // multiple files are only being opened for read (esp. startup). +// +// File operations must use "at" functions whenever possible: +// * Local operations must use AT_EMPTY_PATH: +// fchownat(fd, "", AT_EMPTY_PATH, ...), instead of chown(fullpath, ...) +// * Creation operations must use (fd + name): +// mkdirat(fd, name, ...), instead of mkdir(fullpath, ...) +// +// Apart from being faster, it also adds another layer of defense against +// symlink attacks (note that O_NOFOLLOW applies only to the last element in +// the path). +// +// The few exceptions where it cannot be done are: utimensat on symlinks, and +// Connect() for the socket address. type localFile struct { - p9.DefaultWalkGetAttr + p9.DisallowClientCalls // attachPoint is the attachPoint that serves this localFile. attachPoint *attachPoint - // hostPath will be safely updated by the Renamed hook. + // hostPath is the full path to the host file. It can be used for logging and + // the few cases where full path is required to operation the host file. In + // all other cases, use "file" directly. + // + // Note: it's safely updated by the Renamed hook. hostPath string // file is opened when localFile is created and it's never nil. It may be @@ -212,12 +198,19 @@ type localFile struct { // opened with. file *fd.FD + // controlReadable tells whether 'file' was opened with read permissions + // during a walk. + controlReadable bool + // mode is the mode in which the file was opened. Set to invalidMode // if localFile isn't opened. mode p9.OpenFlags - // ft is the fileType for this file. - ft fileType + // fileType for this file. It is equivalent to: + // unix.Stat_t.Mode & unix.S_IFMT + fileType uint32 + + qid p9.QID // readDirMu protects against concurrent Readdir calls. readDirMu sync.Mutex @@ -234,7 +227,7 @@ var procSelfFD *fd.FD // OpenProcSelfFD opens the /proc/self/fd directory, which will be used to // reopen file descriptors. func OpenProcSelfFD() error { - d, err := syscall.Open("/proc/self/fd", syscall.O_RDONLY|syscall.O_DIRECTORY, 0) + d, err := unix.Open("/proc/self/fd", unix.O_RDONLY|unix.O_DIRECTORY, 0) if err != nil { return fmt.Errorf("error opening /proc/self/fd: %v", err) } @@ -243,7 +236,7 @@ func OpenProcSelfFD() error { } func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { - d, err := syscall.Openat(int(procSelfFD.FD()), strconv.Itoa(f.FD()), mode&^syscall.O_NOFOLLOW, 0) + d, err := unix.Openat(int(procSelfFD.FD()), strconv.Itoa(f.FD()), mode&^unix.O_NOFOLLOW, 0) if err != nil { return nil, err } @@ -251,83 +244,88 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { return fd.New(d), nil } -func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, error) { - path := path.Join(parent.hostPath, name) - f, err := openAnyFile(path, func(mode int) (*fd.FD, error) { +func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) { + pathDebug := path.Join(parent.hostPath, name) + f, readable, err := openAnyFile(pathDebug, func(mode int) (*fd.FD, error) { return fd.OpenAt(parent.file, name, openFlags|mode, 0) }) - return f, path, err + return f, pathDebug, readable, err } -// openAnyFile attempts to open the file in O_RDONLY and if it fails fallsback +// openAnyFile attempts to open the file in O_RDONLY. If it fails, falls back // to O_PATH. 'path' is used for logging messages only. 'fn' is what does the // actual file open and is customizable by the caller. -func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) { +func openAnyFile(pathDebug string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, error) { // Attempt to open file in the following mode in order: // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs. // Use non-blocking to prevent getting stuck inside open(2) for // FIFOs. This option has no effect on regular files. // 2. PATH: for symlinks, sockets. - modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH} + options := []struct { + mode int + readable bool + }{ + { + mode: unix.O_RDONLY | unix.O_NONBLOCK, + readable: true, + }, + { + mode: unix.O_PATH, + readable: false, + }, + } var err error - var file *fd.FD - for i, mode := range modes { - file, err = fn(mode) + for i, option := range options { + var file *fd.FD + file, err = fn(option.mode) if err == nil { - // openat succeeded, we're done. - break + // Succeeded opening the file, we're done. + return file, option.readable, nil } switch e := extractErrno(err); e { - case syscall.ENOENT: + case unix.ENOENT: // File doesn't exist, no point in retrying. - return nil, e + return nil, false, e } - // openat failed. Try again with next mode, preserving 'err' in case this - // was the last attempt. - log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|mode, path, err) + // File failed to open. Try again with next mode, preserving 'err' in case + // this was the last attempt. + log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|option.mode, pathDebug, err) } - if err != nil { - // All attempts to open file have failed, return the last error. - log.Debugf("Failed to open file, path: %q, err: %v", path, err) - return nil, extractErrno(err) - } - - return file, nil + // All attempts to open file have failed, return the last error. + log.Debugf("Failed to open file, path: %q, err: %v", pathDebug, err) + return nil, false, extractErrno(err) } -func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) { - var ft fileType - switch stat.Mode & syscall.S_IFMT { - case syscall.S_IFREG: - ft = regular - case syscall.S_IFDIR: - ft = directory - case syscall.S_IFLNK: - ft = symlink - case syscall.S_IFSOCK: +func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error { + switch stat.Mode & unix.S_IFMT { + case unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK: + return nil + + case unix.S_IFSOCK: if !permitSocket { - return unknown, syscall.EPERM + return unix.EPERM } - ft = socket + return nil + default: - return unknown, syscall.EPERM + return unix.EPERM } - return ft, nil } -func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) { - ft, err := getSupportedFileType(stat, a.conf.HostUDS) - if err != nil { +func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat unix.Stat_t) (*localFile, error) { + if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil { return nil, err } return &localFile{ - attachPoint: a, - hostPath: path, - file: file, - mode: invalidMode, - ft: ft, + attachPoint: a, + hostPath: path, + file: file, + mode: invalidMode, + fileType: stat.Mode & unix.S_IFMT, + qid: a.makeQID(stat), + controlReadable: readable, }, nil } @@ -335,7 +333,7 @@ func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) // non-blocking. If anything fails, returns nil. It's better to have a file // without host FD, than to fail the operation. func newFDMaybe(file *fd.FD) *fd.FD { - dupFD, err := syscall.Dup(file.FD()) + dupFD, err := unix.Dup(file.FD()) // Technically, the runtime may call the finalizer on file as soon as // FD() returns. runtime.KeepAlive(file) @@ -345,23 +343,23 @@ func newFDMaybe(file *fd.FD) *fd.FD { dup := fd.New(dupFD) // fd is blocking; non-blocking is required. - if err := syscall.SetNonblock(dup.FD(), true); err != nil { - dup.Close() + if err := unix.SetNonblock(dup.FD(), true); err != nil { + _ = dup.Close() return nil } return dup } -func stat(fd int) (syscall.Stat_t, error) { - var stat syscall.Stat_t - if err := syscall.Fstat(fd, &stat); err != nil { - return syscall.Stat_t{}, err +func fstat(fd int) (unix.Stat_t, error) { + var stat unix.Stat_t + if err := unix.Fstat(fd, &stat); err != nil { + return unix.Stat_t{}, err } return stat, nil } func fchown(fd int, uid p9.UID, gid p9.GID) error { - return syscall.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) + return unix.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) } // Open implements p9.File. @@ -369,10 +367,16 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { if l.isOpen() { panic(fmt.Sprintf("attempting to open already opened file: %q", l.hostPath)) } + mode := flags & p9.OpenFlagsModeMask + if mode == p9.WriteOnly || mode == p9.ReadWrite || flags&p9.OpenTruncate != 0 { + if err := l.checkROMount(); err != nil { + return nil, p9.QID{}, 0, err + } + } // Check if control file can be used or if a new open must be created. var newFile *fd.FD - if flags == p9.ReadOnly { + if mode == p9.ReadOnly && l.controlReadable && flags.OSFlags()&allowedOpenFlags == 0 { log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath) newFile = l.file } else { @@ -381,23 +385,15 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { // name_to_handle_at and open_by_handle_at aren't supported by overlay2. log.Debugf("Open reopening file, flags: %v, %q", flags, l.hostPath) var err error - // Constrain open flags to the open mode and O_TRUNC. - newFile, err = reopenProcFd(l.file, openFlags|(flags.OSFlags()&(syscall.O_ACCMODE|syscall.O_TRUNC))) + osFlags := flags.OSFlags() & (unix.O_ACCMODE | allowedOpenFlags) + newFile, err = reopenProcFd(l.file, openFlags|osFlags) if err != nil { return nil, p9.QID{}, 0, extractErrno(err) } } - stat, err := stat(newFile.FD()) - if err != nil { - if newFile != l.file { - newFile.Close() - } - return nil, p9.QID{}, 0, extractErrno(err) - } - var fd *fd.FD - if stat.Mode&syscall.S_IFMT == syscall.S_IFREG { + if l.fileType == unix.S_IFREG { // Donate FD for regular files only. fd = newFDMaybe(newFile) } @@ -409,38 +405,38 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { } l.file = newFile } - l.mode = flags & p9.OpenFlagsModeMask - return fd, l.attachPoint.makeQID(stat), 0, nil + l.mode = mode + return fd, l.qid, 0, nil } // Create implements p9.File. -func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9.File, p9.QID, uint32, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return nil, nil, p9.QID{}, 0, syscall.EBADF +func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9.File, p9.QID, uint32, error) { + if err := l.checkROMount(); err != nil { + return nil, nil, p9.QID{}, 0, err } + // Set file creation flags, plus allowed open flags from caller. + osFlags := openFlags | unix.O_CREAT | unix.O_EXCL + osFlags |= p9Flags.OSFlags() & allowedOpenFlags + // 'file' may be used for other operations (e.g. Walk), so read access is // always added to flags. Note that resulting file might have a wider mode // than needed for each particular case. - flags := openFlags | syscall.O_CREAT | syscall.O_EXCL + mode := p9Flags & p9.OpenFlagsModeMask if mode == p9.WriteOnly { - flags |= syscall.O_RDWR + osFlags |= unix.O_RDWR } else { - flags |= mode.OSFlags() + osFlags |= mode.OSFlags() } - child, err := fd.OpenAt(l.file, name, flags, uint32(perm.Permissions())) + child, err := fd.OpenAt(l.file, name, osFlags, uint32(perm.Permissions())) if err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } cu := cleanup.Make(func() { - child.Close() + _ = child.Close() // Best effort attempt to remove the file in case of failure. - if err := syscall.Unlinkat(l.file.FD(), name); err != nil { + if err := unix.Unlinkat(l.file.FD(), name, 0); err != nil { log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err) } }) @@ -449,7 +445,7 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid if err := fchown(child.FD(), uid, gid); err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } - stat, err := stat(child.FD()) + stat, err := fstat(child.FD()) if err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } @@ -459,23 +455,21 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid hostPath: path.Join(l.hostPath, name), file: child, mode: mode, + fileType: unix.S_IFREG, + qid: l.attachPoint.makeQID(stat), } cu.Release() - return newFDMaybe(c.file), c, l.attachPoint.makeQID(stat), 0, nil + return newFDMaybe(c.file), c, c.qid, 0, nil } // Mkdir implements p9.File. func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return p9.QID{}, syscall.EBADF + if err := l.checkROMount(); err != nil { + return p9.QID{}, err } - if err := syscall.Mkdirat(l.file.FD(), name, uint32(perm.Permissions())); err != nil { + if err := unix.Mkdirat(l.file.FD(), name, uint32(perm.Permissions())); err != nil { return p9.QID{}, extractErrno(err) } cu := cleanup.Make(func() { @@ -487,7 +481,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) defer cu.Clean() // Open directory to change ownership and stat it. - flags := syscall.O_DIRECTORY | syscall.O_RDONLY | openFlags + flags := unix.O_DIRECTORY | unix.O_RDONLY | openFlags f, err := fd.OpenAt(l.file, name, flags, 0) if err != nil { return p9.QID{}, extractErrno(err) @@ -497,7 +491,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) if err := fchown(f.FD(), uid, gid); err != nil { return p9.QID{}, extractErrno(err) } - stat, err := stat(f.FD()) + stat, err := fstat(f.FD()) if err != nil { return p9.QID{}, extractErrno(err) } @@ -508,61 +502,80 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) // Walk implements p9.File. func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { + qids, file, _, err := l.walk(names) + return qids, file, err +} + +// WalkGetAttr implements p9.File. +func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { + qids, file, stat, err := l.walk(names) + if err != nil { + return nil, nil, p9.AttrMask{}, p9.Attr{}, err + } + mask, attr := l.fillAttr(stat) + return qids, file, mask, attr, nil +} + +func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error) { // Duplicate current file if 'names' is empty. if len(names) == 0 { - newFile, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { + newFile, readable, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { return reopenProcFd(l.file, openFlags|mode) }) if err != nil { - return nil, nil, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } - stat, err := stat(newFile.FD()) + stat, err := fstat(newFile.FD()) if err != nil { - newFile.Close() - return nil, nil, extractErrno(err) + _ = newFile.Close() + return nil, nil, unix.Stat_t{}, extractErrno(err) } c := &localFile{ - attachPoint: l.attachPoint, - hostPath: l.hostPath, - file: newFile, - mode: invalidMode, + attachPoint: l.attachPoint, + hostPath: l.hostPath, + file: newFile, + mode: invalidMode, + fileType: l.fileType, + qid: l.attachPoint.makeQID(stat), + controlReadable: readable, } - return []p9.QID{l.attachPoint.makeQID(stat)}, c, nil + return []p9.QID{c.qid}, c, stat, nil } var qids []p9.QID + var lastStat unix.Stat_t last := l for _, name := range names { - f, path, err := openAnyFileFromParent(last, name) + f, path, readable, err := openAnyFileFromParent(last, name) if last != l { - last.Close() + _ = last.Close() } if err != nil { - return nil, nil, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } - stat, err := stat(f.FD()) + lastStat, err = fstat(f.FD()) if err != nil { - f.Close() - return nil, nil, extractErrno(err) + _ = f.Close() + return nil, nil, unix.Stat_t{}, extractErrno(err) } - c, err := newLocalFile(last.attachPoint, f, path, stat) + c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat) if err != nil { - f.Close() - return nil, nil, extractErrno(err) + _ = f.Close() + return nil, nil, unix.Stat_t{}, extractErrno(err) } - qids = append(qids, l.attachPoint.makeQID(stat)) + qids = append(qids, c.qid) last = c } - return qids, last, nil + return qids, last, lastStat, nil } // StatFS implements p9.File. func (l *localFile) StatFS() (p9.FSStat, error) { - var s syscall.Statfs_t - if err := syscall.Fstatfs(l.file.FD(), &s); err != nil { + var s unix.Statfs_t + if err := unix.Fstatfs(l.file.FD(), &s); err != nil { return p9.FSStat{}, extractErrno(err) } @@ -582,9 +595,9 @@ func (l *localFile) StatFS() (p9.FSStat, error) { // FSync implements p9.File. func (l *localFile) FSync() error { if !l.isOpen() { - return syscall.EBADF + return unix.EBADF } - if err := syscall.Fsync(l.file.FD()); err != nil { + if err := unix.Fsync(l.file.FD()); err != nil { return extractErrno(err) } return nil @@ -592,11 +605,15 @@ func (l *localFile) FSync() error { // GetAttr implements p9.File. func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - stat, err := stat(l.file.FD()) + stat, err := fstat(l.file.FD()) if err != nil { return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err) } + mask, attr := l.fillAttr(stat) + return l.qid, mask, attr, nil +} +func (l *localFile) fillAttr(stat unix.Stat_t) (p9.AttrMask, p9.Attr) { attr := p9.Attr{ Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), @@ -625,20 +642,15 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) MTime: true, CTime: true, } - - return l.attachPoint.makeQID(stat), valid, attr, nil + return valid, attr } // SetAttr implements p9.File. Due to mismatch in file API, options // cannot be changed atomically and user may see partial changes when // an error happens. func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } allowed := p9.SetAttrMask{ @@ -661,13 +673,13 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { // consistent result that is not attribute dependent. if !valid.IsSubsetOf(allowed) { log.Warningf("SetAttr() failed for %q, mask: %v", l.hostPath, valid) - return syscall.EPERM + return unix.EPERM } // Check if it's possible to use cached file, or if another one needs to be // opened for write. f := l.file - if l.ft == regular && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { + if l.fileType == unix.S_IFREG && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { var err error f, err = reopenProcFd(l.file, openFlags|os.O_WRONLY) if err != nil { @@ -688,21 +700,21 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { // over another. var err error if valid.Permissions { - if cerr := syscall.Fchmod(f.FD(), uint32(attr.Permissions)); cerr != nil { + if cerr := unix.Fchmod(f.FD(), uint32(attr.Permissions)); cerr != nil { log.Debugf("SetAttr fchmod failed %q, err: %v", l.hostPath, cerr) err = extractErrno(cerr) } } if valid.Size { - if terr := syscall.Ftruncate(f.FD(), int64(attr.Size)); terr != nil { + if terr := unix.Ftruncate(f.FD(), int64(attr.Size)); terr != nil { log.Debugf("SetAttr ftruncate failed %q, err: %v", l.hostPath, terr) err = extractErrno(terr) } } if valid.ATime || valid.MTime { - utimes := [2]syscall.Timespec{ + utimes := [2]unix.Timespec{ {Sec: 0, Nsec: linux.UTIME_OMIT}, {Sec: 0, Nsec: linux.UTIME_OMIT}, } @@ -723,15 +735,15 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } } - if l.ft == symlink { + if l.fileType == unix.S_IFLNK { // utimensat operates different that other syscalls. To operate on a // symlink it *requires* AT_SYMLINK_NOFOLLOW with dirFD and a non-empty // name. - parent, err := syscall.Open(path.Dir(l.hostPath), openFlags|unix.O_PATH, 0) + parent, err := unix.Open(path.Dir(l.hostPath), openFlags|unix.O_PATH, 0) if err != nil { return extractErrno(err) } - defer syscall.Close(parent) + defer unix.Close(parent) if terr := utimensat(parent, path.Base(l.hostPath), utimes, linux.AT_SYMLINK_NOFOLLOW); terr != nil { log.Debugf("SetAttr utimens failed %q, err: %v", l.hostPath, terr) @@ -756,7 +768,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { if valid.GID { gid = int(attr.GID) } - if oerr := syscall.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oerr != nil { + if oerr := unix.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oerr != nil { log.Debugf("SetAttr fchownat failed %q, err: %v", l.hostPath, oerr) err = extractErrno(oerr) } @@ -766,28 +778,28 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } func (*localFile) GetXattr(string, uint64) (string, error) { - return "", syscall.EOPNOTSUPP + return "", unix.EOPNOTSUPP } func (*localFile) SetXattr(string, string, uint32) error { - return syscall.EOPNOTSUPP + return unix.EOPNOTSUPP } func (*localFile) ListXattr(uint64) (map[string]struct{}, error) { - return nil, syscall.EOPNOTSUPP + return nil, unix.EOPNOTSUPP } func (*localFile) RemoveXattr(string) error { - return syscall.EOPNOTSUPP + return unix.EOPNOTSUPP } // Allocate implements p9.File. func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error { if !l.isOpen() { - return syscall.EBADF + return unix.EBADF } - if err := syscall.Fallocate(l.file.FD(), mode.ToLinux(), int64(offset), int64(length)); err != nil { + if err := unix.Fallocate(l.file.FD(), mode.ToLinux(), int64(offset), int64(length)); err != nil { return extractErrno(err) } return nil @@ -800,12 +812,8 @@ func (*localFile) Rename(p9.File, string) error { // RenameAt implements p9.File.RenameAt. func (l *localFile) RenameAt(oldName string, directory p9.File, newName string) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } newParent := directory.(*localFile) @@ -818,10 +826,10 @@ func (l *localFile) RenameAt(oldName string, directory p9.File, newName string) // ReadAt implements p9.File. func (l *localFile) ReadAt(p []byte, offset uint64) (int, error) { if l.mode != p9.ReadOnly && l.mode != p9.ReadWrite { - return 0, syscall.EBADF + return 0, unix.EBADF } if !l.isOpen() { - return 0, syscall.EBADF + return 0, unix.EBADF } r, err := l.file.ReadAt(p, int64(offset)) @@ -836,10 +844,10 @@ func (l *localFile) ReadAt(p []byte, offset uint64) (int, error) { // WriteAt implements p9.File. func (l *localFile) WriteAt(p []byte, offset uint64) (int, error) { if l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { - return 0, syscall.EBADF + return 0, unix.EBADF } if !l.isOpen() { - return 0, syscall.EBADF + return 0, unix.EBADF } w, err := l.file.WriteAt(p, int64(offset)) @@ -851,12 +859,8 @@ func (l *localFile) WriteAt(p []byte, offset uint64) (int, error) { // Symlink implements p9.File. func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return p9.QID{}, syscall.EBADF + if err := l.checkROMount(); err != nil { + return p9.QID{}, err } if err := unix.Symlinkat(target, l.file.FD(), newName); err != nil { @@ -864,7 +868,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. } cu := cleanup.Make(func() { // Best effort attempt to remove the symlink in case of failure. - if err := syscall.Unlinkat(l.file.FD(), newName); err != nil { + if err := unix.Unlinkat(l.file.FD(), newName, 0); err != nil { log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, newName), err) } }) @@ -880,7 +884,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. if err := fchown(f.FD(), uid, gid); err != nil { return p9.QID{}, extractErrno(err) } - stat, err := stat(f.FD()) + stat, err := fstat(f.FD()) if err != nil { return p9.QID{}, extractErrno(err) } @@ -891,12 +895,8 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. // Link implements p9.File. func (l *localFile) Link(target p9.File, newName string) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } targetFile := target.(*localFile) @@ -907,23 +907,53 @@ func (l *localFile) Link(target p9.File, newName string) error { } // Mknod implements p9.File. -// -// Not implemented. -func (*localFile) Mknod(_ string, _ p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) { +func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid p9.UID, gid p9.GID) (p9.QID, error) { + if err := l.checkROMount(); err != nil { + return p9.QID{}, err + } + // From mknod(2) man page: // "EPERM: [...] if the filesystem containing pathname does not support // the type of node requested." - return p9.QID{}, syscall.EPERM + if mode.FileType() != p9.ModeRegular { + return p9.QID{}, unix.EPERM + } + + // Allow Mknod to create regular files. + if err := unix.Mknodat(l.file.FD(), name, uint32(mode), 0); err != nil { + return p9.QID{}, err + } + cu := cleanup.Make(func() { + // Best effort attempt to remove the file in case of failure. + if err := unix.Unlinkat(l.file.FD(), name, 0); err != nil { + log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err) + } + }) + defer cu.Clean() + + // Open file to change ownership and stat it. + child, err := fd.OpenAt(l.file, name, unix.O_PATH|openFlags, 0) + if err != nil { + return p9.QID{}, extractErrno(err) + } + defer child.Close() + + if err := fchown(child.FD(), uid, gid); err != nil { + return p9.QID{}, extractErrno(err) + } + stat, err := fstat(child.FD()) + if err != nil { + return p9.QID{}, extractErrno(err) + } + + cu.Release() + return l.attachPoint.makeQID(stat), nil } // UnlinkAt implements p9.File. func (l *localFile) UnlinkAt(name string, flags uint32) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } if err := unix.Unlinkat(l.file.FD(), name, int(flags)); err != nil { @@ -935,10 +965,10 @@ func (l *localFile) UnlinkAt(name string, flags uint32) error { // Readdir implements p9.File. func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { if l.mode != p9.ReadOnly && l.mode != p9.ReadWrite { - return nil, syscall.EBADF + return nil, unix.EBADF } if !l.isOpen() { - return nil, syscall.EBADF + return nil, unix.EBADF } // Readdirnames is a cursor over directories, so seek back to 0 to ensure it's @@ -949,10 +979,13 @@ func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { skip := uint64(0) - // Check if the file is at the correct position already. If not, seek to the - // beginning and read the entire directory again. - if l.lastDirentOffset != offset { - if _, err := syscall.Seek(l.file.FD(), 0, 0); err != nil { + // Check if the file is at the correct position already. If not, seek to + // the beginning and read the entire directory again. We always seek if + // offset is 0, since this is side-effectual (equivalent to rewinddir(3), + // which causes the directory stream to resynchronize with the directory's + // current contents). + if l.lastDirentOffset != offset || offset == 0 { + if _, err := unix.Seek(l.file.FD(), 0, 0); err != nil { return nil, extractErrno(err) } skip = offset @@ -985,7 +1018,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) end := offset + uint64(count) for offset < end { - dirSize, err := syscall.ReadDirent(f, direntsBuf) + dirSize, err := unix.ReadDirent(f, direntsBuf) if err != nil { return dirents, err } @@ -994,7 +1027,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) } names := names[:0] - _, _, names = syscall.ParseDirent(direntsBuf[:dirSize], -1, names) + _, _, names = unix.ParseDirent(direntsBuf[:dirSize], -1, names) // Skip over entries that the caller is not interested in. if skip > 0 { @@ -1039,7 +1072,7 @@ func (l *localFile) Readlink() (string, error) { return string(b[:n]), nil } } - return "", syscall.ENOMEM + return "", unix.ENOMEM } // Flush implements p9.File. @@ -1050,7 +1083,7 @@ func (l *localFile) Flush() error { // Connect implements p9.File. func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { if !l.attachPoint.conf.HostUDS { - return nil, syscall.ECONNREFUSED + return nil, unix.ECONNREFUSED } // TODO(gvisor.dev/issue/1003): Due to different app vs replacement @@ -1058,34 +1091,34 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { // fit f.path in our sockaddr. We'd need to redirect through a shorter // path in order to actually connect to this socket. if len(l.hostPath) > linux.UnixPathMax { - return nil, syscall.ECONNREFUSED + return nil, unix.ECONNREFUSED } var stype int switch flags { case p9.StreamSocket: - stype = syscall.SOCK_STREAM + stype = unix.SOCK_STREAM case p9.DgramSocket: - stype = syscall.SOCK_DGRAM + stype = unix.SOCK_DGRAM case p9.SeqpacketSocket: - stype = syscall.SOCK_SEQPACKET + stype = unix.SOCK_SEQPACKET default: - return nil, syscall.ENXIO + return nil, unix.ENXIO } - f, err := syscall.Socket(syscall.AF_UNIX, stype, 0) + f, err := unix.Socket(unix.AF_UNIX, stype, 0) if err != nil { return nil, err } - if err := syscall.SetNonblock(f, true); err != nil { - syscall.Close(f) + if err := unix.SetNonblock(f, true); err != nil { + _ = unix.Close(f) return nil, err } - sa := syscall.SockaddrUnix{Name: l.hostPath} - if err := syscall.Connect(f, &sa); err != nil { - syscall.Close(f) + sa := unix.SockaddrUnix{Name: l.hostPath} + if err := unix.Connect(f, &sa); err != nil { + _ = unix.Close(f) return nil, err } @@ -1110,7 +1143,7 @@ func (l *localFile) Renamed(newDir p9.File, newName string) { } // extractErrno tries to determine the errno. -func extractErrno(err error) syscall.Errno { +func extractErrno(err error) unix.Errno { if err == nil { // This should never happen. The likely result will be that // some user gets the frustrating "error: SUCCESS" message. @@ -1120,18 +1153,18 @@ func extractErrno(err error) syscall.Errno { switch err { case os.ErrNotExist: - return syscall.ENOENT + return unix.ENOENT case os.ErrExist: - return syscall.EEXIST + return unix.EEXIST case os.ErrPermission: - return syscall.EACCES + return unix.EACCES case os.ErrInvalid: - return syscall.EINVAL + return unix.EINVAL } // See if it's an errno or a common wrapped error. switch e := err.(type) { - case syscall.Errno: + case unix.Errno: return e case *os.PathError: return extractErrno(e.Err) @@ -1143,5 +1176,12 @@ func extractErrno(err error) syscall.Errno { // Fall back to EIO. log.Debugf("Unknown error: %v, defaulting to EIO", err) - return syscall.EIO + return unix.EIO +} + +func (l *localFile) checkROMount() error { + if conf := l.attachPoint.conf; conf.ROMount { + return unix.EROFS + } + return nil } diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go index 5d4aab597..c46958185 100644 --- a/runsc/fsgofer/fsgofer_amd64_unsafe.go +++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go @@ -17,25 +17,25 @@ package fsgofer import ( - "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) -func statAt(dirFd int, name string) (syscall.Stat_t, error) { - nameBytes, err := syscall.BytePtrFromString(name) +func statAt(dirFd int, name string) (unix.Stat_t, error) { + nameBytes, err := unix.BytePtrFromString(name) if err != nil { - return syscall.Stat_t{}, err + return unix.Stat_t{}, err } namePtr := unsafe.Pointer(nameBytes) - var stat syscall.Stat_t + var stat unix.Stat_t statPtr := unsafe.Pointer(&stat) - if _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, + if _, _, errno := unix.Syscall6( + unix.SYS_NEWFSTATAT, uintptr(dirFd), uintptr(namePtr), uintptr(statPtr), @@ -43,7 +43,7 @@ func statAt(dirFd int, name string) (syscall.Stat_t, error) { 0, 0); errno != 0 { - return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + return unix.Stat_t{}, syserr.FromHost(errno).ToError() } return stat, nil } diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go index 8041fd352..491460718 100644 --- a/runsc/fsgofer/fsgofer_arm64_unsafe.go +++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go @@ -17,25 +17,25 @@ package fsgofer import ( - "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) -func statAt(dirFd int, name string) (syscall.Stat_t, error) { - nameBytes, err := syscall.BytePtrFromString(name) +func statAt(dirFd int, name string) (unix.Stat_t, error) { + nameBytes, err := unix.BytePtrFromString(name) if err != nil { - return syscall.Stat_t{}, err + return unix.Stat_t{}, err } namePtr := unsafe.Pointer(nameBytes) - var stat syscall.Stat_t + var stat unix.Stat_t statPtr := unsafe.Pointer(&stat) - if _, _, errno := syscall.Syscall6( - syscall.SYS_FSTATAT, + if _, _, errno := unix.Syscall6( + unix.SYS_FSTATAT, uintptr(dirFd), uintptr(namePtr), uintptr(statPtr), @@ -43,7 +43,7 @@ func statAt(dirFd int, name string) (syscall.Stat_t, error) { 0, 0); errno != 0 { - return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + return unix.Stat_t{}, syserr.FromHost(errno).ToError() } return stat, nil } diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index 05af7e397..a84206686 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -21,11 +21,24 @@ import ( "os" "path" "path/filepath" - "syscall" "testing" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} + +var ( + allTypes = []uint32{unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK} + + // allConfs is set in init(). + allConfs []Config + + rwConfs = []Config{{ROMount: false}} + roConfs = []Config{{ROMount: true}} ) func init() { @@ -39,6 +52,13 @@ func init() { } } +func configTestName(conf *Config) string { + if conf.ROMount { + return "ROMount" + } + return "RWMount" +} + func assertPanic(t *testing.T, f func()) { defer func() { if r := recover(); r == nil { @@ -63,7 +83,7 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { } want = append(want, b...) } else { - if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF { + if e, ok := err.(unix.Errno); !ok || e != unix.EBADF { return fmt.Errorf("WriteAt() should have failed, got: %d, want: EBADFD", err) } } @@ -81,78 +101,83 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { return fmt.Errorf("ReadAt() wrong data, got: %s, want: %s", string(rBuf), want) } } else { - if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF { + if e, ok := err.(unix.Errno); !ok || e != unix.EBADF { return fmt.Errorf("ReadAt() should have failed, got: %d, want: EBADFD", err) } } return nil } -var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} - -var ( - allTypes = []fileType{regular, directory, symlink} - - // allConfs is set in init() above. - allConfs []Config - - rwConfs = []Config{{ROMount: false}} - roConfs = []Config{{ROMount: true}} -) - type state struct { - root *localFile - file *localFile - conf Config - ft fileType + root *localFile + file *localFile + conf Config + fileType uint32 } func (s state) String() string { - return fmt.Sprintf("type(%v)", s.ft) + return fmt.Sprintf("type(%v)", s.fileType) +} + +func typeName(fileType uint32) string { + switch fileType { + case unix.S_IFREG: + return "file" + case unix.S_IFDIR: + return "directory" + case unix.S_IFLNK: + return "symlink" + default: + panic(fmt.Sprintf("invalid file type for test: %d", fileType)) + } } func runAll(t *testing.T, test func(*testing.T, state)) { runCustom(t, allTypes, allConfs, test) } -func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) { +func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.T, state)) { for _, c := range confs { - t.Logf("Config: %+v", c) - for _, ft := range types { - t.Logf("File type: %v", ft) + name := fmt.Sprintf("%s/%s", configTestName(&c), typeName(ft)) + t.Run(name, func(t *testing.T) { + path, name, err := setup(ft) + if err != nil { + t.Fatalf("%v", err) + } + defer os.RemoveAll(path) - path, name, err := setup(ft) - if err != nil { - t.Fatalf("%v", err) - } - defer os.RemoveAll(path) + a, err := NewAttachPoint(path, c) + if err != nil { + t.Fatalf("NewAttachPoint failed: %v", err) + } + root, err := a.Attach() + if err != nil { + t.Fatalf("Attach failed, err: %v", err) + } - a, err := NewAttachPoint(path, c) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - root, err := a.Attach() - if err != nil { - t.Fatalf("Attach failed, err: %v", err) - } + _, file, err := root.Walk([]string{name}) + if err != nil { + root.Close() + t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) + } - _, file, err := root.Walk([]string{name}) - if err != nil { + st := state{ + root: root.(*localFile), + file: file.(*localFile), + conf: c, + fileType: ft, + } + test(t, st) + file.Close() root.Close() - t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) - } - - st := state{root: root.(*localFile), file: file.(*localFile), conf: c, ft: ft} - test(t, st) - file.Close() - root.Close() + }) } } } -func setup(ft fileType) (string, string, error) { - path, err := ioutil.TempDir("", "root-") +func setup(fileType uint32) (string, string, error) { + path, err := ioutil.TempDir(testutil.TmpDir(), "root-") if err != nil { return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err) } @@ -169,26 +194,26 @@ func setup(ft fileType) (string, string, error) { defer root.Close() var name string - switch ft { - case regular: + switch fileType { + case unix.S_IFREG: name = "file" _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err) } defer f.Close() - case directory: + case unix.S_IFDIR: name = "dir" if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err) } - case symlink: + case unix.S_IFLNK: name = "symlink" if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err) } default: - panic(fmt.Sprintf("unknown file type %v", ft)) + panic(fmt.Sprintf("unknown file type %v", fileType)) } return path, name, nil } @@ -202,7 +227,7 @@ func createFile(dir *localFile, name string) (*localFile, error) { } func TestReadWrite(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -221,9 +246,13 @@ func TestReadWrite(t *testing.T) { if err != nil { t.Fatalf("%v: Walk(%s) failed, err: %v", s, "test", err) } - if _, _, _, err := l.Open(flags); err != nil { + fd, _, _, err := l.Open(flags) + if err != nil { t.Fatalf("%v: Open(%v) failed, err: %v", s, flags, err) } + if fd != nil { + defer fd.Close() + } if err := testReadWrite(l, flags, want); err != nil { t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err) } @@ -232,14 +261,14 @@ func TestReadWrite(t *testing.T) { } func TestCreate(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { for i, flags := range allOpenFlags { _, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { t.Fatalf("%v, %v: WriteAt() failed, err: %v", s, flags, err) } - if err := testReadWrite(l, flags, []byte{}); err != nil { + if err := testReadWrite(l, flags, nil); err != nil { t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err) } } @@ -249,7 +278,7 @@ func TestCreate(t *testing.T) { // TestReadWriteDup tests that a file opened in any mode can be dup'ed and // reopened in any other mode. func TestReadWriteDup(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -279,9 +308,13 @@ func TestReadWriteDup(t *testing.T) { t.Fatalf("%v: Walk(<empty>) failed: %v", s, err) } defer dup.Close() - if _, _, _, err := dup.Open(dupFlags); err != nil { + fd, _, _, err := dup.Open(dupFlags) + if err != nil { t.Fatalf("%v: Open(%v) failed: %v", s, flags, err) } + if fd != nil { + defer fd.Close() + } if err := testReadWrite(dup, dupFlags, want); err != nil { t.Fatalf("%v: testReadWrite(%v) failed: %v", s, dupFlags, err) } @@ -291,19 +324,45 @@ func TestReadWriteDup(t *testing.T) { } func TestUnopened(t *testing.T) { - runCustom(t, []fileType{regular}, allConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFREG}, allConfs, func(t *testing.T, s state) { b := []byte("foobar") - if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF { - t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.WriteAt(b, 0); err != unix.EBADF { + t.Errorf("%v: WriteAt() should have failed, got: %v, expected: unix.EBADF", s, err) } - if _, err := s.file.ReadAt(b, 0); err != syscall.EBADF { - t.Errorf("%v: ReadAt() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.ReadAt(b, 0); err != unix.EBADF { + t.Errorf("%v: ReadAt() should have failed, got: %v, expected: unix.EBADF", s, err) } - if _, err := s.file.Readdir(0, 100); err != syscall.EBADF { - t.Errorf("%v: Readdir() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.Readdir(0, 100); err != unix.EBADF { + t.Errorf("%v: Readdir() should have failed, got: %v, expected: unix.EBADF", s, err) } - if err := s.file.FSync(); err != syscall.EBADF { - t.Errorf("%v: FSync() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.FSync(); err != unix.EBADF { + t.Errorf("%v: FSync() should have failed, got: %v, expected: unix.EBADF", s, err) + } + }) +} + +// TestOpenOPath is a regression test to ensure that a file that cannot be open +// for read is allowed to be open. This was happening because the control file +// was open with O_PATH, but Open() was not checking for it and allowing the +// control file to be reused. +func TestOpenOPath(t *testing.T) { + runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s state) { + // Fist remove all permissions on the file. + if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil { + t.Fatalf("SetAttr(): %v", err) + } + // Then walk to the file again to open a new control file. + filename := filepath.Base(s.file.hostPath) + _, newFile, err := s.root.Walk([]string{filename}) + if err != nil { + t.Fatalf("root.Walk(%q): %v", filename, err) + } + + if newFile.(*localFile).controlReadable { + t.Fatalf("control file didn't open with O_PATH: %+v", newFile) + } + if _, _, _, err := newFile.Open(p9.ReadOnly); err != unix.EACCES { + t.Fatalf("Open() should have failed, got: %v, wanted: EACCES", err) } }) } @@ -324,7 +383,7 @@ func TestSetAttrPerm(t *testing.T) { valid := p9.SetAttrMask{Permissions: true} attr := p9.SetAttr{Permissions: 0777} got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink { + if s.fileType == unix.S_IFLNK { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -345,7 +404,7 @@ func TestSetAttrSize(t *testing.T) { valid := p9.SetAttrMask{Size: true} attr := p9.SetAttr{Size: size} got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink || s.ft == directory { + if s.fileType == unix.S_IFLNK || s.fileType == unix.S_IFDIR { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -427,9 +486,9 @@ func TestLink(t *testing.T) { } err = dir.Link(s.file, linkFile) - if s.ft == directory { - if err != syscall.EPERM { - t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err) + if s.fileType == unix.S_IFDIR { + if err != unix.EPERM { + t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: unix.EPERM", s, linkFile, err) } return } @@ -440,54 +499,64 @@ func TestLink(t *testing.T) { } func TestROMountChecks(t *testing.T) { + const want = unix.EROFS + uid := p9.UID(os.Getuid()) + gid := p9.GID(os.Getgid()) + runCustom(t, allTypes, roConfs, func(t *testing.T, s state) { - if _, _, _, _, err := s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: Create() should have failed, got: %v, expected: syscall.EBADF", s, err) + if s.fileType != unix.S_IFLNK { + if _, _, _, err := s.file.Open(p9.WriteOnly); err != want { + t.Errorf("Open() should have failed, got: %v, expected: %v", err, want) + } + if _, _, _, err := s.file.Open(p9.ReadWrite); err != want { + t.Errorf("Open() should have failed, got: %v, expected: %v", err, want) + } + if _, _, _, err := s.file.Open(p9.ReadOnly | p9.OpenTruncate); err != want { + t.Errorf("Open() should have failed, got: %v, expected: %v", err, want) + } + f, _, _, err := s.file.Open(p9.ReadOnly) + if err != nil { + t.Errorf("Open() failed: %v", err) + } + if f != nil { + _ = f.Close() + } } - if _, err := s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: MkDir() should have failed, got: %v, expected: syscall.EBADF", s, err) + + if _, _, _, _, err := s.file.Create("some_file", p9.ReadWrite, 0777, uid, gid); err != want { + t.Errorf("Create() should have failed, got: %v, expected: %v", err, want) } - if err := s.file.RenameAt("some_file", s.file, "other_file"); err != syscall.EBADF { - t.Errorf("%v: Rename() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.Mkdir("some_dir", 0777, uid, gid); err != want { + t.Errorf("MkDir() should have failed, got: %v, expected: %v", err, want) } - if _, err := s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: Symlink() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.RenameAt("some_file", s.file, "other_file"); err != want { + t.Errorf("Rename() should have failed, got: %v, expected: %v", err, want) } - if err := s.file.UnlinkAt("some_file", 0); err != syscall.EBADF { - t.Errorf("%v: UnlinkAt() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.Symlink("some_place", "some_symlink", uid, gid); err != want { + t.Errorf("Symlink() should have failed, got: %v, expected: %v", err, want) } - if err := s.file.Link(s.file, "some_link"); err != syscall.EBADF { - t.Errorf("%v: Link() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.UnlinkAt("some_file", 0); err != want { + t.Errorf("UnlinkAt() should have failed, got: %v, expected: %v", err, want) } - - valid := p9.SetAttrMask{Size: true} - attr := p9.SetAttr{Size: 0} - if err := s.file.SetAttr(valid, attr); err != syscall.EBADF { - t.Errorf("%v: SetAttr() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.Link(s.file, "some_link"); err != want { + t.Errorf("Link() should have failed, got: %v, expected: %v", err, want) + } + if _, err := s.file.Mknod("some-nod", 0777, 1, 2, uid, gid); err != want { + t.Errorf("Mknod() should have failed, got: %v, expected: %v", err, want) } - }) -} - -func TestROMountPanics(t *testing.T) { - conf := Config{ROMount: true, PanicOnWrite: true} - runCustom(t, allTypes, []Config{conf}, func(t *testing.T, s state) { - assertPanic(t, func() { s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.RenameAt("some_file", s.file, "other_file") }) - assertPanic(t, func() { s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.UnlinkAt("some_file", 0) }) - assertPanic(t, func() { s.file.Link(s.file, "some_link") }) valid := p9.SetAttrMask{Size: true} attr := p9.SetAttr{Size: 0} - assertPanic(t, func() { s.file.SetAttr(valid, attr) }) + if err := s.file.SetAttr(valid, attr); err != want { + t.Errorf("SetAttr() should have failed, got: %v, expected: %v", err, want) + } }) } func TestWalkNotFound(t *testing.T) { - runCustom(t, []fileType{directory}, allConfs, func(t *testing.T, s state) { - if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT { - t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err) + runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) { + if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT { + t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: unix.ENOENT", s, "nobody-here", err) } }) } @@ -506,7 +575,7 @@ func TestWalkDup(t *testing.T) { } func TestReaddir(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { name := "dir" if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err) @@ -631,7 +700,7 @@ func TestAttachInvalidType(t *testing.T) { defer os.RemoveAll(dir) fifo := filepath.Join(dir, "fifo") - if err := syscall.Mkfifo(fifo, 0755); err != nil { + if err := unix.Mkfifo(fifo, 0755); err != nil { t.Fatalf("Mkfifo(%q): %v", fifo, err) } @@ -690,3 +759,63 @@ func TestDoubleAttachError(t *testing.T) { t.Fatalf("Attach should have failed, got %v want non-nil", err) } } + +func TestTruncate(t *testing.T) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + child, err := createFile(s.file, "test") + if err != nil { + t.Fatalf("createFile() failed: %v", err) + } + defer child.Close() + want := []byte("foobar") + w, err := child.WriteAt(want, 0) + if err != nil { + t.Fatalf("Write() failed: %v", err) + } + if w != len(want) { + t.Fatalf("Write() was partial, got: %d, expected: %d", w, len(want)) + } + + _, l, err := s.file.Walk([]string{"test"}) + if err != nil { + t.Fatalf("Walk(%s) failed: %v", "test", err) + } + if _, _, _, err := l.Open(p9.ReadOnly | p9.OpenTruncate); err != nil { + t.Fatalf("Open() failed: %v", err) + } + _, mask, attr, err := l.GetAttr(p9.AttrMask{Size: true}) + if err != nil { + t.Fatalf("GetAttr() failed: %v", err) + } + if !mask.Size { + t.Fatalf("GetAttr() didn't return size: %+v", mask) + } + if attr.Size != 0 { + t.Fatalf("truncate didn't work, want: 0, got: %d", attr.Size) + } + }) +} + +func TestMknod(t *testing.T) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + _, err := s.file.Mknod("test", p9.ModeRegular|0777, 1, 2, p9.UID(os.Getuid()), p9.GID(os.Getgid())) + if err != nil { + t.Fatalf("Mknod() failed: %v", err) + } + + _, f, err := s.file.Walk([]string{"test"}) + if err != nil { + t.Fatalf("Walk() failed: %v", err) + } + fd, _, _, err := f.Open(p9.ReadWrite) + if err != nil { + t.Fatalf("Open() failed: %v", err) + } + if fd != nil { + defer fd.Close() + } + if err := testReadWrite(f, p9.ReadWrite, nil); err != nil { + t.Fatalf("testReadWrite() failed: %v", err) + } + }) +} diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go index 542b54365..f11fea40d 100644 --- a/runsc/fsgofer/fsgofer_unsafe.go +++ b/runsc/fsgofer/fsgofer_unsafe.go @@ -15,18 +15,18 @@ package fsgofer import ( - "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/syserr" ) -func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) error { +func utimensat(dirFd int, name string, times [2]unix.Timespec, flags int) error { // utimensat(2) doesn't accept empty name, instead name must be nil to make it // operate directly on 'dirFd' unlike other *at syscalls. var namePtr unsafe.Pointer if name != "" { - nameBytes, err := syscall.BytePtrFromString(name) + nameBytes, err := unix.BytePtrFromString(name) if err != nil { return err } @@ -35,8 +35,8 @@ func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) err timesPtr := unsafe.Pointer(×[0]) - if _, _, errno := syscall.Syscall6( - syscall.SYS_UTIMENSAT, + if _, _, errno := unix.Syscall6( + unix.SYS_UTIMENSAT, uintptr(dirFd), uintptr(namePtr), uintptr(timesPtr), @@ -52,7 +52,7 @@ func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) err func renameat(oldDirFD int, oldName string, newDirFD int, newName string) error { var oldNamePtr unsafe.Pointer if oldName != "" { - nameBytes, err := syscall.BytePtrFromString(oldName) + nameBytes, err := unix.BytePtrFromString(oldName) if err != nil { return err } @@ -60,15 +60,15 @@ func renameat(oldDirFD int, oldName string, newDirFD int, newName string) error } var newNamePtr unsafe.Pointer if newName != "" { - nameBytes, err := syscall.BytePtrFromString(newName) + nameBytes, err := unix.BytePtrFromString(newName) if err != nil { return err } newNamePtr = unsafe.Pointer(nameBytes) } - if _, _, errno := syscall.Syscall6( - syscall.SYS_RENAMEAT, + if _, _, errno := unix.Syscall6( + unix.SYS_RENAMEAT, uintptr(oldDirFD), uintptr(oldNamePtr), uintptr(newDirFD), diff --git a/runsc/main.go b/runsc/main.go index c9f47c579..ed244c4ba 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -23,8 +23,6 @@ import ( "io/ioutil" "os" "os/signal" - "path/filepath" - "strings" "syscall" "time" @@ -32,8 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/cmd" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" ) @@ -41,57 +39,17 @@ import ( var ( // Although these flags are not part of the OCI spec, they are used by // Docker, and thus should not be changed. - rootDir = flag.String("root", "", "root directory for storage of container state.") - logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout.") - logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") - debug = flag.Bool("debug", false, "enable debug logging.") - showVersion = flag.Bool("version", false, "show version and exit.") // TODO(gvisor.dev/issue/193): support systemd cgroups systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") + showVersion = flag.Bool("version", false, "show version and exit.") // These flags are unique to runsc, and are used to configure parts of the // system that are not covered by the runtime spec. // Debugging flags. - debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") - panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") - logPackets = flag.Bool("log-packets", false, "enable network packet logging.") - logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") - debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") - panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") - debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") - alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.") - - // Debugging flags: strace related - strace = flag.Bool("strace", false, "enable strace.") - straceSyscalls = flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") - straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") - - // Flags that control sandbox runtime behavior. - platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") - network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") - hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") - softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.") - txChecksumOffload = flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.") - rxChecksumOffload = flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.") - qDisc = flag.String("qdisc", "fifo", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") - fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") - fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") - overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") - overlayfsStaleRead = flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") - watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.") - panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") - profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") - netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") - numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") - rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") - referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.") - cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") - vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") - - // Test flags, not to be used outside tests, ever. - testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") - testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") + debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") + panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") ) func main() { @@ -135,6 +93,8 @@ func main() { subcommands.Register(new(cmd.Gofer), internalGroup) subcommands.Register(new(cmd.Statefile), internalGroup) + config.RegisterFlags() + // All subcommands must be registered before flag parsing. flag.Parse() @@ -146,6 +106,12 @@ func main() { os.Exit(0) } + // Create a new Config from the flags. + conf, err := config.NewFromFlags() + if err != nil { + cmd.Fatalf(err.Error()) + } + // TODO(gvisor.dev/issue/193): support systemd cgroups if *systemdCgroup { fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") @@ -156,102 +122,28 @@ func main() { if *logFD > -1 { errorLogger = os.NewFile(uintptr(*logFD), "error log file") - } else if *logFilename != "" { + } else if conf.LogFilename != "" { // We must set O_APPEND and not O_TRUNC because Docker passes // the same log file for all commands (and also parses these // log files), so we can't destroy them on each command. var err error - errorLogger, err = os.OpenFile(*logFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) if err != nil { - cmd.Fatalf("error opening log file %q: %v", *logFilename, err) + cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err) } } cmd.ErrorLogger = errorLogger - platformType := *platformName - if _, err := platform.Lookup(platformType); err != nil { - cmd.Fatalf("%v", err) - } - - fsAccess, err := boot.MakeFileAccessType(*fileAccess) - if err != nil { - cmd.Fatalf("%v", err) - } - - if fsAccess == boot.FileAccessShared && *overlay { - cmd.Fatalf("overlay flag is incompatible with shared file access") - } - - netType, err := boot.MakeNetworkType(*network) - if err != nil { + if _, err := platform.Lookup(conf.Platform); err != nil { cmd.Fatalf("%v", err) } - wa, err := boot.MakeWatchdogAction(*watchdogAction) - if err != nil { - cmd.Fatalf("%v", err) - } - - if *numNetworkChannels <= 0 { - cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels) - } - - refsLeakMode, err := boot.MakeRefsLeakMode(*referenceLeakMode) - if err != nil { - cmd.Fatalf("%v", err) - } - - queueingDiscipline, err := boot.MakeQueueingDiscipline(*qDisc) - if err != nil { - cmd.Fatalf("%s", err) - } - // Sets the reference leak check mode. Also set it in config below to // propagate it to child processes. - refs.SetLeakMode(refsLeakMode) - - // Create a new Config from the flags. - conf := &boot.Config{ - RootDir: *rootDir, - Debug: *debug, - LogFilename: *logFilename, - LogFormat: *logFormat, - DebugLog: *debugLog, - PanicLog: *panicLog, - DebugLogFormat: *debugLogFormat, - FileAccess: fsAccess, - FSGoferHostUDS: *fsGoferHostUDS, - Overlay: *overlay, - Network: netType, - HardwareGSO: *hardwareGSO, - SoftwareGSO: *softwareGSO, - TXChecksumOffload: *txChecksumOffload, - RXChecksumOffload: *rxChecksumOffload, - LogPackets: *logPackets, - Platform: platformType, - Strace: *strace, - StraceLogSize: *straceLogSize, - WatchdogAction: wa, - PanicSignal: *panicSignal, - ProfileEnable: *profile, - EnableRaw: *netRaw, - NumNetworkChannels: *numNetworkChannels, - Rootless: *rootless, - AlsoLogToStderr: *alsoLogToStderr, - ReferenceLeakMode: refsLeakMode, - OverlayfsStaleRead: *overlayfsStaleRead, - CPUNumFromQuota: *cpuNumFromQuota, - VFS2: *vfs2Enabled, - QDisc: queueingDiscipline, - TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, - TestOnlyTestNameEnv: *testOnlyTestNameEnv, - } - if len(*straceSyscalls) != 0 { - conf.StraceSyscalls = strings.Split(*straceSyscalls, ",") - } + refs.SetLeakMode(conf.ReferenceLeak) // Set up logging. - if *debug { + if conf.Debug { log.SetLevel(log.Debug) } @@ -273,14 +165,14 @@ func main() { if *debugLogFD > -1 { f := os.NewFile(uintptr(*debugLogFD), "debug log file") - e = newEmitter(*debugLogFormat, f) + e = newEmitter(conf.DebugLogFormat, f) - } else if *debugLog != "" { - f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */) + } else if conf.DebugLog != "" { + f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */) if err != nil { - cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err) + cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err) } - e = newEmitter(*debugLogFormat, f) + e = newEmitter(conf.DebugLogFormat, f) } else { // Stderr is reserved for the application, just discard the logs if no debug @@ -306,8 +198,8 @@ func main() { if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) } - } else if *alsoLogToStderr { - e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)} + } else if conf.AlsoLogToStderr { + e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)} } log.SetTarget(e) @@ -326,7 +218,7 @@ func main() { log.Infof("\t\tVFS2 enabled: %v", conf.VFS2) log.Infof("***************************") - if *testOnlyAllowRunAsCurrentUserWithoutChroot { + if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { // SIGTERM is sent to all processes if a test exceeds its // timeout and this case is handled by syscall_test_runner. log.Warningf("Block the TERM signal. This is only safe in tests!") @@ -362,11 +254,3 @@ func newEmitter(format string, logFile io.Writer) log.Emitter { cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format) panic("unreachable") } - -func init() { - // Set default root dir to something (hopefully) user-writeable. - *rootDir = "/var/run/runsc" - if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { - *rootDir = filepath.Join(runtimeDir, "runsc") - } -} diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 035dcd3e3..f0a551a1e 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -26,10 +26,11 @@ go_library( "//runsc/boot", "//runsc/boot/platforms", "//runsc/cgroup", + "//runsc/config", "//runsc/console", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@com_github_vishvananda_netlink//:go_default_library", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index deee619f3..0b9f39466 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -49,26 +50,26 @@ import ( // // Run the following container to test it: // docker run -di --runtime=runsc -p 8080:80 -v $PWD:/usr/local/apache2/htdocs/ httpd:2.4 -func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Config) error { +func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *config.Config) error { log.Infof("Setting up network") switch conf.Network { - case boot.NetworkNone: + case config.NetworkNone: log.Infof("Network is disabled, create loopback interface only") if err := createDefaultLoopbackInterface(conn); err != nil { return fmt.Errorf("creating default loopback interface: %v", err) } - case boot.NetworkSandbox: + case config.NetworkSandbox: // Build the path to the net namespace of the sandbox process. // This is what we will copy. nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net") if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.TXChecksumOffload, conf.RXChecksumOffload, conf.NumNetworkChannels, conf.QDisc); err != nil { return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err) } - case boot.NetworkHost: + case config.NetworkHost: // Nothing to do here. default: - return fmt.Errorf("invalid network type: %d", conf.Network) + return fmt.Errorf("invalid network type: %v", conf.Network) } return nil } @@ -115,7 +116,7 @@ func isRootNS() (bool, error) { // createInterfacesAndRoutesFromNS scrapes the interface and routes from the // net namespace with the given path, creates them in the sandbox, and removes // them from the host. -func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, txChecksumOffload bool, rxChecksumOffload bool, numNetworkChannels int, qDisc boot.QueueingDiscipline) error { +func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, txChecksumOffload bool, rxChecksumOffload bool, numNetworkChannels int, qDisc config.QueueingDiscipline) error { // Join the network namespace that we will be copying. restore, err := joinNetNS(nsPath) if err != nil { @@ -134,7 +135,6 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG return err } if isRoot { - return fmt.Errorf("cannot run with network enabled in root network namespace") } diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 6e1a2af25..c4309feb3 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -41,6 +41,7 @@ import ( "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" "gvisor.dev/gvisor/runsc/cgroup" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/console" "gvisor.dev/gvisor/runsc/specutils" ) @@ -71,11 +72,14 @@ type Sandbox struct { // will have it as a child process. child bool - // status is an exit status of a sandbox process. - status syscall.WaitStatus - // statusMu protects status. statusMu sync.Mutex + + // status is the exit status of a sandbox process. It's only set if the + // child==true and the sandbox was waited on. This field allows for multiple + // threads to wait on sandbox and get the exit code, since Linux will return + // WaitStatus to one of the waiters only. + status syscall.WaitStatus } // Args is used to configure a new sandbox. @@ -116,7 +120,7 @@ type Args struct { // New creates the sandbox process. The caller must call Destroy() on the // sandbox. -func New(conf *boot.Config, args *Args) (*Sandbox, error) { +func New(conf *config.Config, args *Args) (*Sandbox, error) { s := &Sandbox{ID: args.ID, Cgroup: args.Cgroup} // The Cleanup object cleans up partially created sandboxes when an error // occurs. Any errors occurring during cleanup itself are ignored. @@ -180,7 +184,7 @@ func (s *Sandbox) CreateContainer(cid string) error { } // StartRoot starts running the root container process inside the sandbox. -func (s *Sandbox) StartRoot(spec *specs.Spec, conf *boot.Config) error { +func (s *Sandbox) StartRoot(spec *specs.Spec, conf *config.Config) error { log.Debugf("Start root sandbox %q, PID: %d", s.ID, s.Pid) conn, err := s.sandboxConnect() if err != nil { @@ -203,7 +207,7 @@ func (s *Sandbox) StartRoot(spec *specs.Spec, conf *boot.Config) error { } // StartContainer starts running a non-root container inside the sandbox. -func (s *Sandbox) StartContainer(spec *specs.Spec, conf *boot.Config, cid string, goferFiles []*os.File) error { +func (s *Sandbox) StartContainer(spec *specs.Spec, conf *config.Config, cid string, goferFiles []*os.File) error { for _, f := range goferFiles { defer f.Close() } @@ -232,7 +236,7 @@ func (s *Sandbox) StartContainer(spec *specs.Spec, conf *boot.Config, cid string } // Restore sends the restore call for a container in the sandbox. -func (s *Sandbox) Restore(cid string, spec *specs.Spec, conf *boot.Config, filename string) error { +func (s *Sandbox) Restore(cid string, spec *specs.Spec, conf *config.Config, filename string) error { log.Debugf("Restore sandbox %q", s.ID) rf, err := os.Open(filename) @@ -344,7 +348,7 @@ func (s *Sandbox) connError(err error) error { // createSandboxProcess starts the sandbox as a subprocess by running the "boot" // command, passing in the bundle dir. -func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncFile *os.File) error { +func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyncFile *os.File) error { // nextFD is used to get unused FDs that we can pass to the sandbox. It // starts at 3 because 0, 1, and 2 are taken by stdin/out/err. nextFD := 3 @@ -477,12 +481,10 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF cmd.Stderr = nil // If the console control socket file is provided, then create a new - // pty master/slave pair and set the TTY on the sandbox process. - if args.ConsoleSocket != "" { - cmd.Args = append(cmd.Args, "--console=true") - + // pty master/replica pair and set the TTY on the sandbox process. + if args.Spec.Process.Terminal && args.ConsoleSocket != "" { // console.NewWithSocket will send the master on the given - // socket, and return the slave. + // socket, and return the replica. tty, err := console.NewWithSocket(args.ConsoleSocket) if err != nil { return fmt.Errorf("setting up console with socket %q: %v", args.ConsoleSocket, err) @@ -557,10 +559,10 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF // Joins the network namespace if network is enabled. the sandbox talks // directly to the host network, which may have been configured in the // namespace. - if ns, ok := specutils.GetNS(specs.NetworkNamespace, args.Spec); ok && conf.Network != boot.NetworkNone { + if ns, ok := specutils.GetNS(specs.NetworkNamespace, args.Spec); ok && conf.Network != config.NetworkNone { log.Infof("Sandbox will be started in the container's network namespace: %+v", ns) nss = append(nss, ns) - } else if conf.Network == boot.NetworkHost { + } else if conf.Network == config.NetworkHost { log.Infof("Sandbox will be started in the host network namespace") } else { log.Infof("Sandbox will be started in new network namespace") @@ -570,7 +572,7 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF // User namespace depends on the network type. Host network requires to run // inside the user namespace specified in the spec or the current namespace // if none is configured. - if conf.Network == boot.NetworkHost { + if conf.Network == config.NetworkHost { if userns, ok := specutils.GetNS(specs.UserNamespace, args.Spec); ok { log.Infof("Sandbox will be started in container's user namespace: %+v", userns) nss = append(nss, userns) @@ -747,35 +749,47 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF // Wait waits for the containerized process to exit, and returns its WaitStatus. func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { log.Debugf("Waiting for container %q in sandbox %q", cid, s.ID) - var ws syscall.WaitStatus if conn, err := s.sandboxConnect(); err != nil { - // The sandbox may have exited while before we had a chance to - // wait on it. + // The sandbox may have exited while before we had a chance to wait on it. + // There is nothing we can do for subcontainers. For the init container, we + // can try to get the sandbox exit code. + if !s.IsRootContainer(cid) { + return syscall.WaitStatus(0), err + } log.Warningf("Wait on container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } else { defer conn.Close() + // Try the Wait RPC to the sandbox. + var ws syscall.WaitStatus err = conn.Call(boot.ContainerWait, &cid, &ws) if err == nil { // It worked! return ws, nil } + // See comment above. + if !s.IsRootContainer(cid) { + return syscall.WaitStatus(0), err + } + // The sandbox may have exited after we connected, but before // or during the Wait RPC. log.Warningf("Wait RPC to container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } - // The sandbox may have already exited, or exited while handling the - // Wait RPC. The best we can do is ask Linux what the sandbox exit - // status was, since in most cases that will be the same as the - // container exit status. + // The sandbox may have already exited, or exited while handling the Wait RPC. + // The best we can do is ask Linux what the sandbox exit status was, since in + // most cases that will be the same as the container exit status. if err := s.waitForStopped(); err != nil { - return ws, err + return syscall.WaitStatus(0), err } if !s.child { - return ws, fmt.Errorf("sandbox no longer running and its exit status is unavailable") + return syscall.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable") } + + s.statusMu.Lock() + defer s.statusMu.Unlock() return s.status, nil } @@ -1014,26 +1028,6 @@ func (s *Sandbox) StopCPUProfile() error { return nil } -// GoroutineProfile writes a goroutine profile to the given file. -func (s *Sandbox) GoroutineProfile(f *os.File) error { - log.Debugf("Goroutine profile %q", s.ID) - conn, err := s.sandboxConnect() - if err != nil { - return err - } - defer conn.Close() - - opts := control.ProfileOpts{ - FilePayload: urpc.FilePayload{ - Files: []*os.File{f}, - }, - } - if err := conn.Call(boot.GoroutineProfile, &opts, nil); err != nil { - return fmt.Errorf("getting sandbox %q goroutine profile: %v", s.ID, err) - } - return nil -} - // BlockProfile writes a block profile to the given file. func (s *Sandbox) BlockProfile(f *os.File) error { log.Debugf("Block profile %q", s.ID) @@ -1201,7 +1195,7 @@ func deviceFileForPlatform(name string) (*os.File, error) { // checkBinaryPermissions verifies that the required binary bits are set on // the runsc executable. -func checkBinaryPermissions(conf *boot.Config) error { +func checkBinaryPermissions(conf *config.Config) error { // All platforms need the other exe bit neededBits := os.FileMode(0001) if conf.Platform == platforms.Ptrace { diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index 62d4f5113..679d8bc8e 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -16,9 +16,10 @@ go_library( "//pkg/bits", "//pkg/log", "//pkg/sentry/kernel/auth", + "//runsc/config", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_mohae_deepcopy//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], @@ -29,5 +30,5 @@ go_test( size = "small", srcs = ["specutils_test.go"], library = ":specutils", - deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"], + deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], ) diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD new file mode 100644 index 000000000..3520f2d6d --- /dev/null +++ b/runsc/specutils/seccomp/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "seccomp", + srcs = [ + "audit_amd64.go", + "audit_arm64.go", + "seccomp.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/abi/linux", + "//pkg/bpf", + "//pkg/log", + "//pkg/seccomp", + "//pkg/sentry/kernel", + "//pkg/sentry/syscalls/linux", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) + +go_test( + name = "seccomp_test", + size = "small", + srcs = ["seccomp_test.go"], + library = ":seccomp", + deps = [ + "//pkg/binary", + "//pkg/bpf", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/runsc/specutils/seccomp/audit_amd64.go b/runsc/specutils/seccomp/audit_amd64.go new file mode 100644 index 000000000..417cf4a7a --- /dev/null +++ b/runsc/specutils/seccomp/audit_amd64.go @@ -0,0 +1,25 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package seccomp + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" +) + +const ( + nativeArchAuditNo = linux.AUDIT_ARCH_X86_64 +) diff --git a/runsc/specutils/seccomp/audit_arm64.go b/runsc/specutils/seccomp/audit_arm64.go new file mode 100644 index 000000000..b727ceff2 --- /dev/null +++ b/runsc/specutils/seccomp/audit_arm64.go @@ -0,0 +1,25 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package seccomp + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" +) + +const ( + nativeArchAuditNo = linux.AUDIT_ARCH_AARCH64 +) diff --git a/runsc/specutils/seccomp/seccomp.go b/runsc/specutils/seccomp/seccomp.go new file mode 100644 index 000000000..5932f7a41 --- /dev/null +++ b/runsc/specutils/seccomp/seccomp.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package seccomp implements some features of libseccomp in order to support +// OCI. +package seccomp + +import ( + "fmt" + "syscall" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/seccomp" + "gvisor.dev/gvisor/pkg/sentry/kernel" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" +) + +var ( + killThreadAction = linux.SECCOMP_RET_KILL_THREAD + trapAction = linux.SECCOMP_RET_TRAP + // runc always returns EPERM as the errorcode for SECCOMP_RET_ERRNO + errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(syscall.EPERM)) + // runc always returns EPERM as the errorcode for SECCOMP_RET_TRACE + traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(syscall.EPERM)) + allowAction = linux.SECCOMP_RET_ALLOW +) + +// BuildProgram generates a bpf program based on the given OCI seccomp +// config. +func BuildProgram(s *specs.LinuxSeccomp) (bpf.Program, error) { + defaultAction, err := convertAction(s.DefaultAction) + if err != nil { + return bpf.Program{}, fmt.Errorf("secomp default action: %w", err) + } + ruleset, err := convertRules(s) + if err != nil { + return bpf.Program{}, fmt.Errorf("invalid seccomp rules: %w", err) + } + + instrs, err := seccomp.BuildProgram(ruleset, defaultAction, killThreadAction) + if err != nil { + return bpf.Program{}, fmt.Errorf("building seccomp program: %w", err) + } + + program, err := bpf.Compile(instrs) + if err != nil { + return bpf.Program{}, fmt.Errorf("compiling seccomp program: %w", err) + } + + return program, nil +} + +// lookupSyscallNo gets the syscall number for the syscall with the given name +// for the given architecture. +func lookupSyscallNo(arch uint32, name string) (uint32, error) { + var table *kernel.SyscallTable + switch arch { + case linux.AUDIT_ARCH_X86_64: + table = slinux.AMD64 + case linux.AUDIT_ARCH_AARCH64: + table = slinux.ARM64 + } + if table == nil { + return 0, fmt.Errorf("unsupported architecture: %d", arch) + } + n, err := table.LookupNo(name) + if err != nil { + return 0, err + } + return uint32(n), nil +} + +// convertAction converts a LinuxSeccompAction to BPFAction +func convertAction(act specs.LinuxSeccompAction) (linux.BPFAction, error) { + // TODO(gvisor.dev/issue/3124): Update specs package to include ActLog and ActKillProcess. + switch act { + case specs.ActKill: + return killThreadAction, nil + case specs.ActTrap: + return trapAction, nil + case specs.ActErrno: + return errnoAction, nil + case specs.ActTrace: + return traceAction, nil + case specs.ActAllow: + return allowAction, nil + default: + return 0, fmt.Errorf("invalid action: %v", act) + } +} + +// convertRules converts OCI linux seccomp rules into RuleSets that can be used by +// the seccomp package to build a seccomp program. +func convertRules(s *specs.LinuxSeccomp) ([]seccomp.RuleSet, error) { + // NOTE: Architectures are only really relevant when calling 32bit syscalls + // on a 64bit system. Since we don't support that in gVisor anyway, we + // ignore Architectures and only test against the native architecture. + + ruleset := []seccomp.RuleSet{} + + for _, syscall := range s.Syscalls { + sysRules := seccomp.NewSyscallRules() + + action, err := convertAction(syscall.Action) + if err != nil { + return nil, err + } + + // Args + rules, err := convertArgs(syscall.Args) + if err != nil { + return nil, err + } + + for _, name := range syscall.Names { + syscallNo, err := lookupSyscallNo(nativeArchAuditNo, name) + if err != nil { + // If there is an error looking up the syscall number, assume it is + // not supported on this architecture and ignore it. This is, for + // better or worse, what runc does. + log.Warningf("OCI seccomp: ignoring syscall %q", name) + continue + } + + for _, rule := range rules { + sysRules.AddRule(uintptr(syscallNo), rule) + } + } + + ruleset = append(ruleset, seccomp.RuleSet{ + Rules: sysRules, + Action: action, + }) + } + + return ruleset, nil +} + +// convertArgs converts an OCI seccomp argument rule to a list of seccomp.Rule. +func convertArgs(args []specs.LinuxSeccompArg) ([]seccomp.Rule, error) { + argCounts := make([]uint, 6) + + for _, arg := range args { + if arg.Index > 6 { + return nil, fmt.Errorf("invalid index: %d", arg.Index) + } + + argCounts[arg.Index]++ + } + + // NOTE: If multiple rules apply to the same argument (same index) the + // action is triggered if any one of the rules matches (OR). If not, then + // all rules much match in order to trigger the action (AND). This appears to + // be some kind of legacy behavior of runc that nevertheless needs to be + // supported to maintain compatibility. + + hasMultipleArgs := false + for _, count := range argCounts { + if count > 1 { + hasMultipleArgs = true + break + } + } + + if hasMultipleArgs { + rules := []seccomp.Rule{} + + // Old runc behavior - do this for compatibility. + // Add rules as ORs by adding separate Rules. + for _, arg := range args { + rule := seccomp.Rule{nil, nil, nil, nil, nil, nil} + + if err := convertRule(arg, &rule); err != nil { + return nil, err + } + + rules = append(rules, rule) + } + + return rules, nil + } + + // Add rules as ANDs by adding to the same Rule. + rule := seccomp.Rule{nil, nil, nil, nil, nil, nil} + for _, arg := range args { + if err := convertRule(arg, &rule); err != nil { + return nil, err + } + } + + return []seccomp.Rule{rule}, nil +} + +// convertRule converts and adds the arg to a rule. +func convertRule(arg specs.LinuxSeccompArg, rule *seccomp.Rule) error { + switch arg.Op { + case specs.OpEqualTo: + rule[arg.Index] = seccomp.EqualTo(arg.Value) + case specs.OpNotEqual: + rule[arg.Index] = seccomp.NotEqual(arg.Value) + case specs.OpGreaterThan: + rule[arg.Index] = seccomp.GreaterThan(arg.Value) + case specs.OpGreaterEqual: + rule[arg.Index] = seccomp.GreaterThanOrEqual(arg.Value) + case specs.OpLessThan: + rule[arg.Index] = seccomp.LessThan(arg.Value) + case specs.OpLessEqual: + rule[arg.Index] = seccomp.LessThanOrEqual(arg.Value) + case specs.OpMaskedEqual: + rule[arg.Index] = seccomp.MaskedEqual(uintptr(arg.Value), uintptr(arg.ValueTwo)) + default: + return fmt.Errorf("unsupported operand: %q", arg.Op) + } + return nil +} diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go new file mode 100644 index 000000000..850c237ba --- /dev/null +++ b/runsc/specutils/seccomp/seccomp_test.go @@ -0,0 +1,414 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccomp + +import ( + "fmt" + "syscall" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bpf" +) + +type seccompData struct { + nr uint32 + arch uint32 + instructionPointer uint64 + args [6]uint64 +} + +// asInput converts a seccompData to a bpf.Input. +func asInput(d seccompData) bpf.Input { + return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +} + +// testInput creates an Input struct with given seccomp input values. +func testInput(arch uint32, syscallName string, args *[6]uint64) bpf.Input { + syscallNo, err := lookupSyscallNo(arch, syscallName) + if err != nil { + // Assume tests set valid syscall names. + panic(err) + } + + if args == nil { + argArray := [6]uint64{0, 0, 0, 0, 0, 0} + args = &argArray + } + + data := seccompData{ + nr: syscallNo, + arch: arch, + args: *args, + } + + return asInput(data) +} + +// testCase holds a seccomp test case. +type testCase struct { + name string + config specs.LinuxSeccomp + input bpf.Input + expected uint32 +} + +var ( + // seccompTests is a list of speccomp test cases. + seccompTests = []testCase{ + { + name: "default_allow", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + }, + input: testInput(nativeArchAuditNo, "read", nil), + expected: uint32(allowAction), + }, + { + name: "default_deny", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActErrno, + }, + input: testInput(nativeArchAuditNo, "read", nil), + expected: uint32(errnoAction), + }, + { + name: "deny_arch", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + }, + Action: specs.ActErrno, + }, + }, + }, + // Syscall matches but the arch is AUDIT_ARCH_X86 so the return + // value is the bad arch action. + input: asInput(seccompData{nr: 183, arch: 0x40000003}), // + expected: uint32(killThreadAction), + }, + { + name: "match_name_errno", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + "chmod", + }, + Action: specs.ActErrno, + }, + { + Names: []string{ + "write", + }, + Action: specs.ActTrace, + }, + }, + }, + input: testInput(nativeArchAuditNo, "getcwd", nil), + expected: uint32(errnoAction), + }, + { + name: "match_name_trace", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + "chmod", + }, + Action: specs.ActErrno, + }, + { + Names: []string{ + "write", + }, + Action: specs.ActTrace, + }, + }, + }, + input: testInput(nativeArchAuditNo, "write", nil), + expected: uint32(traceAction), + }, + { + name: "no_match_name_allow", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + "chmod", + }, + Action: specs.ActErrno, + }, + { + Names: []string{ + "write", + }, + Action: specs.ActTrace, + }, + }, + }, + input: testInput(nativeArchAuditNo, "openat", nil), + expected: uint32(allowAction), + }, + { + name: "simple_match_args", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + expected: uint32(errnoAction), + }, + { + name: "match_args_or", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + Op: specs.OpEqualTo, + }, + { + Index: 0, + Value: syscall.CLONE_VM, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + expected: uint32(errnoAction), + }, + { + name: "match_args_and", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getsockopt", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 1, + Value: syscall.SOL_SOCKET, + Op: specs.OpEqualTo, + }, + { + Index: 2, + Value: syscall.SO_PEERCRED, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET, syscall.SO_PEERCRED}), + expected: uint32(errnoAction), + }, + { + name: "no_match_args_and", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getsockopt", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 1, + Value: syscall.SOL_SOCKET, + Op: specs.OpEqualTo, + }, + { + Index: 2, + Value: syscall.SO_PEERCRED, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET}), + expected: uint32(allowAction), + }, + { + name: "Simple args (no match)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_VM}), + expected: uint32(allowAction), + }, + { + name: "OpMaskedEqual (match)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + ValueTwo: syscall.CLONE_FS, + Op: specs.OpMaskedEqual, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS | syscall.CLONE_VM}), + expected: uint32(errnoAction), + }, + { + name: "OpMaskedEqual (no match)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS | syscall.CLONE_VM, + ValueTwo: syscall.CLONE_FS | syscall.CLONE_VM, + Op: specs.OpMaskedEqual, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + expected: uint32(allowAction), + }, + { + name: "OpMaskedEqual (clone)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActErrno, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + // This comes from the Docker default seccomp + // profile for clone. + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: 0x7e020000, + ValueTwo: 0x0, + Op: specs.OpMaskedEqual, + }, + }, + Action: specs.ActAllow, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{0x50f00}), + expected: uint32(allowAction), + }, + } +) + +// TestRunscSeccomp generates seccomp programs from OCI config and executes +// them using runsc's library, comparing against expected results. +func TestRunscSeccomp(t *testing.T) { + for _, tc := range seccompTests { + t.Run(tc.name, func(t *testing.T) { + runscProgram, err := BuildProgram(&tc.config) + if err != nil { + t.Fatalf("generating runsc BPF: %v", err) + } + + if err := checkProgram(runscProgram, tc.input, tc.expected); err != nil { + t.Fatalf("running runsc BPF: %v", err) + } + }) + } +} + +// checkProgram runs the given program over the given input and checks the +// result against the expected output. +func checkProgram(p bpf.Program, in bpf.Input, expected uint32) error { + result, err := bpf.Exec(p, in) + if err != nil { + return err + } + + if result != expected { + // Include a decoded version of the program in output for debugging purposes. + decoded, _ := bpf.DecodeProgram(p) + return fmt.Errorf("Unexpected result: got: %d, expected: %d\nBPF Program\n%s", result, expected, decoded) + } + + return nil +} diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 5015c3a84..0392e3e83 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/runsc/config" ) // ExePath must point to runsc binary, which is normally the same binary. It's @@ -110,11 +111,6 @@ func ValidateSpec(spec *specs.Spec) error { log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.") } - // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox. - if spec.Linux != nil && spec.Linux.Seccomp != nil { - log.Warningf("Seccomp spec is being ignored") - } - if spec.Linux != nil && spec.Linux.RootfsPropagation != "" { if err := validateRootfsPropagation(spec.Linux.RootfsPropagation); err != nil { return err @@ -161,18 +157,18 @@ func OpenSpec(bundleDir string) (*os.File, error) { // ReadSpec reads an OCI runtime spec from the given bundle directory. // ReadSpec also normalizes all potential relative paths into absolute // path, e.g. spec.Root.Path, mount.Source. -func ReadSpec(bundleDir string) (*specs.Spec, error) { +func ReadSpec(bundleDir string, conf *config.Config) (*specs.Spec, error) { specFile, err := OpenSpec(bundleDir) if err != nil { return nil, fmt.Errorf("error opening spec file %q: %v", filepath.Join(bundleDir, "config.json"), err) } defer specFile.Close() - return ReadSpecFromFile(bundleDir, specFile) + return ReadSpecFromFile(bundleDir, specFile, conf) } // ReadSpecFromFile reads an OCI runtime spec from the given File, and // normalizes all relative paths into absolute by prepending the bundle dir. -func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error) { +func ReadSpecFromFile(bundleDir string, specFile *os.File, conf *config.Config) (*specs.Spec, error) { if _, err := specFile.Seek(0, os.SEEK_SET); err != nil { return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err) } @@ -195,6 +191,20 @@ func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error) m.Source = absPath(bundleDir, m.Source) } } + + // Override flags using annotation to allow customization per sandbox + // instance. + for annotation, val := range spec.Annotations { + const flagPrefix = "dev.gvisor.flag." + if strings.HasPrefix(annotation, flagPrefix) { + name := annotation[len(flagPrefix):] + log.Infof("Overriding flag: %s=%q", name, val) + if err := conf.Override(name, val); err != nil { + return nil, err + } + } + } + return &spec, nil } diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh deleted file mode 100755 index e0f6df438..000000000 --- a/scripts/benchmark.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# gcloud may be installed as a "snap". If it is, include it in PATH. -declare -r snap="/snap/bin" -if [[ -d "${snap}" ]]; then - export PATH="${PATH}:${snap}" -fi - -# Make sure we can find gcloud and exit if not. -which gcloud - -# Exporting for subprocesses as GCP APIs and tools check this environmental -# variable for authentication. -export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}" - -gcloud auth activate-service-account \ - --key-file "${GOOGLE_APPLICATION_CREDENTIALS}" - -gcloud config set project ${PROJECT} -gcloud config set compute/zone ${ZONE} - -bazel run //benchmarks:benchmarks -- \ - --verbose \ - run-gcp \ - "(startup|absl)" \ - --internal \ - --runtime=runc \ - --runtime=runsc \ - --installers=head diff --git a/scripts/common.sh b/scripts/common.sh deleted file mode 100755 index 3ca699e4a..000000000 --- a/scripts/common.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeou pipefail - -# Get the path to the directory this script lives in. -# If this script is being called with `source`, $0 will be the path of the -# *sourcing* script, so we can't use `dirname $0` to find scripts in this -# directory. -if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then - declare -r script_dir="$(dirname "$BASH_SOURCE")" -else - declare -r script_dir="$(dirname "$0")" -fi - -source "${script_dir}/common_build.sh" - -# Ensure it attempts to collect logs in all cases. -trap collect_logs EXIT - -function set_runtime() { - RUNTIME=${1:-runsc} - RUNSC_BIN=/tmp/"${RUNTIME}"/runsc - RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs - RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% -} - -function test_runsc() { - test --test_arg=--runtime=${RUNTIME} "$@" -} - -function install_runsc_for_test() { - local -r test_name=$1 - shift - if [[ -z "${test_name}" ]]; then - echo "Missing mandatory test name" - exit 1 - fi - - # Add test to the name, so it doesn't conflict with other runtimes. - set_runtime $(find_branch_name)_"${test_name}" - - # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name - # down to the runtime. - install_runsc "${RUNTIME}" \ - --TESTONLY-test-name-env=RUNSC_TEST_NAME \ - --debug \ - --strace \ - --log-packets \ - "$@" -} - -# Installs the runsc with given runtime name. set_runtime must have been called -# to set runtime and logs location. -function install_runsc() { - local -r runtime=$1 - shift - - # Prepare the runtime binary. - local -r output=$(build //runsc) - mkdir -p "$(dirname ${RUNSC_BIN})" - cp -f "${output}" "${RUNSC_BIN}" - chmod 0755 "${RUNSC_BIN}" - - # Install the runtime. - sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@" - - # Clear old logs files that may exist. - sudo rm -f "${RUNSC_LOGS_DIR}"/'*' - - # Restart docker to pick up the new runtime configuration. - sudo systemctl restart docker -} diff --git a/scripts/common_build.sh b/scripts/common_build.sh deleted file mode 100755 index 0d9a191b5..000000000 --- a/scripts/common_build.sh +++ /dev/null @@ -1,116 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -which bazel -bazel version - -# Switch into the workspace; only necessary if run with kokoro. -if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then - cd git/repo -elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then - cd github/repo -fi - -# Set the standard bazel flags. -declare -a BAZEL_FLAGS=( - "--show_timestamps" - "--test_output=errors" - "--keep_going" - "--verbose_failures=true" -) -if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then - BAZEL_FLAGS+=( - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" - "--config=remote" - ) -fi -declare -r BAZEL_FLAGS - -# Wrap bazel. -function build() { - bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \ - | tee /dev/fd/2 \ - | grep -E '^ bazel-bin/' \ - | awk '{ print $1; }' -} - -function test() { - bazel test "${BAZEL_FLAGS[@]}" "$@" -} - -function run() { - local binary=$1 - shift - bazel run "${binary}" -- "$@" -} - -function run_as_root() { - local binary=$1 - shift - bazel run --run_under="sudo" "${binary}" -- "$@" -} - -function query() { - QUERY_RESULT=$(bazel query "$@") -} - -function collect_logs() { - # Zip out everything into a convenient form. - if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then - # Merge results files of all shards for each test suite. - for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do - junitparser merge `find $d -name test.xml` $d/test.xml - cat $d/shard_*_of_*/test.log > $d/test.log - if ls -ld $d/shard_*_of_*/test.outputs 2>/dev/null; then - zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs - fi - done - find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf - # Move test logs to Kokoro directory. tar is used to conveniently perform - # renames while moving files. - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - - # Collect sentry logs, if any. - if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then - # Check if the directory is empty or not (only the first line it needed). - local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) - if [[ "${logs}" ]]; then - local -r archive=runsc_logs_"${RUNTIME}".tar.gz - if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then - echo "runsc logs will be uploaded to:" - echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" - echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" - fi - time tar \ - --verbose \ - --create \ - --gzip \ - --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \ - --directory "${RUNSC_LOGS_DIR}" \ - . - fi - fi - fi -} - -function find_branch_name() { - git branch --show-current \ - || git rev-parse HEAD \ - || bazel info workspace \ - | xargs basename -} diff --git a/scripts/dev.sh b/scripts/dev.sh deleted file mode 100755 index a9107f33e..000000000 --- a/scripts/dev.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# common.sh sets '-x', but it's annoying to see so much output. -set +x - -# Defaults -declare -i REFRESH=0 -declare NAME=$(find_branch_name) - -while [[ $# -gt 0 ]]; do - case "$1" in - --refresh) - REFRESH=1 - ;; - --help) - echo "Use this script to build and install runsc with Docker." - echo - echo "usage: $0 [--refresh] [runtime_name]" - exit 1 - ;; - *) - NAME=$1 - ;; - esac - shift -done - -set_runtime "${NAME}" -echo -echo "Using runtime=${RUNTIME}" -echo - -echo Building runsc... -# Build first and fail on error. $() prevents "set -e" from reporting errors. -build //runsc -declare OUTPUT="$(build //runsc)" - -if [[ ${REFRESH} -eq 0 ]]; then - install_runsc "${RUNTIME}" --net-raw - install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets - install_runsc "${RUNTIME}-p" --net-raw --profile - - echo - echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed." - echo "Use --runtime="${RUNTIME}" with your Docker command." - echo " docker run --rm --runtime="${RUNTIME}" hello-world" - echo - echo "If you rebuild, use $0 --refresh." - -else - mkdir -p "$(dirname ${RUNSC_BIN})" - cp -f ${OUTPUT} "${RUNSC_BIN}" - chmod a+rx "${RUNSC_BIN}" - - echo - echo "Runtime ${RUNTIME} refreshed." -fi - -echo "Logs are in: ${RUNSC_LOGS_DIR}" diff --git a/scripts/do_tests.sh b/scripts/do_tests.sh deleted file mode 100755 index a3a387c37..000000000 --- a/scripts/do_tests.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Build runsc. -build //runsc - -# run runsc do without root privileges. -run //runsc --rootless do true -run //runsc --rootless --network=none do true - -# run runsc do with root privileges. -run_as_root //runsc do true diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh deleted file mode 100755 index dce0a4085..000000000 --- a/scripts/docker_tests.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -install_runsc_for_test docker -test_runsc //test/image:image_test //test/e2e:integration_test - -install_runsc_for_test docker --vfs2 -test_runsc //test/image:image_test --test_filter=.*TestHelloWorld diff --git a/scripts/go.sh b/scripts/go.sh deleted file mode 100755 index 626ed8fa4..000000000 --- a/scripts/go.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Build the go path. -build :gopath - -# Build the synthetic branch. -tools/go_branch.sh - -# Checkout the new branch. -git checkout go && git clean -f - -go version - -# Build everything. -go build ./... - -# Push, if required. -if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then - if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then - git config --global credential.helper cache - git credential approve <<EOF -protocol=https -host=github.com -username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}") -password=x-oauth-basic -EOF - fi - git push origin go:go -fi diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh deleted file mode 100755 index 992db50dd..000000000 --- a/scripts/hostnet_tests.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -# Install the runtime and perform basic tests. -install_runsc_for_test hostnet --network=host -test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh deleted file mode 100755 index 8299a7c8b..000000000 --- a/scripts/iptables_tests.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-iptables - -# Needed by ip6tables. -sudo modprobe ip6table_filter - -install_runsc_for_test iptables --net-raw -test //test/iptables:iptables_test "--test_arg=--runtime=runc" -test //test/iptables:iptables_test "--test_arg=--runtime=${RUNTIME}" diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh deleted file mode 100755 index bac9b9192..000000000 --- a/scripts/issue_reviver.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -DIR=$(dirname $0) -source "${DIR}"/common.sh - -# Provide a credential file if available. -export OAUTH_TOKEN_FILE="" -if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then - OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}" -fi - -REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd) -run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}" diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh deleted file mode 100755 index 619571c74..000000000 --- a/scripts/kvm_tests.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -# Ensure that KVM is loaded, and we can use it. -(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm -sudo chmod a+rw /dev/kvm - -# Run all KVM platform tests (locally). -run_as_root //pkg/sentry/platform/kvm:kvm_test - -# Install the KVM runtime and run all integration tests. -install_runsc_for_test kvm --platform=kvm -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh deleted file mode 100755 index 448864953..000000000 --- a/scripts/overlay_tests.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -# Install the runtime and perform basic tests. -install_runsc_for_test overlay --overlay -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh deleted file mode 100755 index 727503bce..000000000 --- a/scripts/packetdrill_tests.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-packetdrill - -install_runsc_for_test runsc-d -query "attr(tags, manual, tests(//test/packetdrill/...))" -test_runsc $QUERY_RESULT diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh deleted file mode 100755 index 51c11f23f..000000000 --- a/scripts/packetimpact_tests.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-packetimpact - -install_runsc_for_test runsc-d -query "attr(tags, packetimpact, tests(//test/packetimpact/...))" -test_runsc $QUERY_RESULT diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh deleted file mode 100755 index d629bf2aa..000000000 --- a/scripts/root_tests.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -# Reinstall the latest containerd shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim - -# Run the tests that require root. -install_runsc_for_test root -run_as_root //test/root:root_test --runtime=${RUNTIME} diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh deleted file mode 100755 index 350a59f7c..000000000 --- a/scripts/runtime_tests.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Check that a runtime is provided. -if [ ! -v RUNTIME_TEST_NAME ]; then - echo "Must set $RUNTIME_TEST_NAME" >&2 - exit 1 -fi - -install_runsc_for_test runtimes -test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test" diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh deleted file mode 100755 index 3a15050c2..000000000 --- a/scripts/simple_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Run all simple tests (locally). -test //pkg/... //runsc/... //tools/... //benchmarks/... //benchmarks/runner:runner_test diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh deleted file mode 100755 index c67f2fe5c..000000000 --- a/scripts/swgso_tests.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -# Install the runtime and perform basic tests. -install_runsc_for_test swgso --software-gso=true --gso=false -test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh deleted file mode 100755 index 0e5d86727..000000000 --- a/scripts/syscall_kvm_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Run all ptrace-variants of the system call tests. -test --test_tag_filters=runsc_kvm //test/syscalls/... diff --git a/scripts/syscall_tests.sh b/scripts/syscall_tests.sh deleted file mode 100755 index a131b2d50..000000000 --- a/scripts/syscall_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# Run all ptrace-variants of the system call tests. -test --test_tag_filters=runsc_ptrace //test/syscalls/... diff --git a/shim/BUILD b/shim/BUILD new file mode 100644 index 000000000..8d29c459b --- /dev/null +++ b/shim/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "pkg_tar") + +package(licenses = ["notice"]) + +pkg_tar( + name = "config", + srcs = [ + "runsc.toml", + ], + mode = "0644", + package_dir = "/etc/containerd", + visibility = [ + "//visibility:public", + ], +) diff --git a/shim/README.md b/shim/README.md new file mode 100644 index 000000000..75daf00ac --- /dev/null +++ b/shim/README.md @@ -0,0 +1,10 @@ +# Shim Overview + +Integration with containerd is done via a [shim][shims]. There are various shims +supported for different versions of [containerd][containerd]. + +- [Containerd 1.2+ (shim v2)](https://gvisor.dev/docs/user_guide/containerd/quick_start/) +- [Containerd 1.1 (shim v1)](https://gvisor.dev/docs/user_guide/containerd/containerd_11/) + +[containerd]: https://github.com/containerd/containerd +[shims]: https://iximiuz.com/en/posts/implementing-container-runtime-shim/ diff --git a/shim/runsc.toml b/shim/runsc.toml new file mode 100644 index 000000000..e1c7de1bb --- /dev/null +++ b/shim/runsc.toml @@ -0,0 +1,6 @@ +# This is an example configuration file for runsc. +# +# By default, it will be parsed from /etc/containerd/runsc.toml, but see the +# static path configured in v1/main.go. Note that the configuration mechanism +# for newer container shim versions is different: see the documentation in v2. +[runsc_config] diff --git a/shim/v1/BUILD b/shim/v1/BUILD new file mode 100644 index 000000000..4c9e2c2c6 --- /dev/null +++ b/shim/v1/BUILD @@ -0,0 +1,30 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "gvisor-containerd-shim", + srcs = [ + "api.go", + "config.go", + "main.go", + ], + static = True, + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/shim", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_containerd//sys:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_ttrpc//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/shim/v1/api.go b/shim/v1/api.go new file mode 100644 index 000000000..2444d23f1 --- /dev/null +++ b/shim/v1/api.go @@ -0,0 +1,24 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + shim "github.com/containerd/containerd/runtime/v1/shim/v1" +) + +type KillRequest = shim.KillRequest + +var registerShimService = shim.RegisterShimService diff --git a/shim/v1/config.go b/shim/v1/config.go new file mode 100644 index 000000000..a72cc7754 --- /dev/null +++ b/shim/v1/config.go @@ -0,0 +1,40 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "github.com/BurntSushi/toml" + +// config is the configuration for gvisor containerd shim. +type config struct { + // RuncShim is the shim binary path for standard containerd-shim for runc. + // When the runtime is `runc`, gvisor containerd shim will exec current + // process to standard containerd-shim. This is a work around for containerd + // 1.1. In containerd 1.2, containerd will choose different containerd-shims + // based on runtime. + RuncShim string `toml:"runc_shim"` + // RunscConfig is configuration for runsc. The key value will be converted + // to runsc flags --key=value directly. + RunscConfig map[string]string `toml:"runsc_config"` +} + +// loadConfig load gvisor containerd shim config from config file. +func loadConfig(path string) (*config, error) { + var c config + _, err := toml.DecodeFile(path, &c) + if err != nil { + return &c, err + } + return &c, nil +} diff --git a/shim/v1/main.go b/shim/v1/main.go new file mode 100644 index 000000000..3159923af --- /dev/null +++ b/shim/v1/main.go @@ -0,0 +1,265 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/sys" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/ttrpc" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/shim" +) + +var ( + debugFlag bool + namespaceFlag string + socketFlag string + addressFlag string + workdirFlag string + runtimeRootFlag string + containerdBinaryFlag string + shimConfigFlag string +) + +// Containerd defaults to runc, unless another runtime is explicitly specified. +// We keep the same default to make the default behavior consistent. +const defaultRoot = "/run/containerd/runc" + +func init() { + flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") + flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") + flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") + flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") + flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data") + flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime") + + // Currently, the `containerd publish` utility is embedded in the + // daemon binary. The daemon invokes `containerd-shim + // -containerd-binary ...` with its own os.Executable() path. + flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)") + flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file") +} + +func main() { + flag.Parse() + + // This is a hack. Exec current process to run standard containerd-shim + // if runtime root is not `runsc`. We don't need this for shim v2 api. + if filepath.Base(runtimeRootFlag) != "runsc" { + if err := executeRuncShim(); err != nil { + fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) + os.Exit(1) + } + } + + // Run regular shim if needed. + if err := executeShim(); err != nil { + fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) + os.Exit(1) + } +} + +// executeRuncShim execs current process to a containerd-shim process and +// retains all flags and envs. +func executeRuncShim() error { + c, err := loadConfig(shimConfigFlag) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load shim config: %w", err) + } + shimPath := c.RuncShim + if shimPath == "" { + shimPath, err = exec.LookPath("containerd-shim") + if err != nil { + return fmt.Errorf("lookup containerd-shim failed: %w", err) + } + } + + args := append([]string{shimPath}, os.Args[1:]...) + if err := syscall.Exec(shimPath, args, os.Environ()); err != nil { + return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err) + } + return nil +} + +func executeShim() error { + // start handling signals as soon as possible so that things are + // properly reaped or if runtime exits before we hit the handler. + signals, err := setupSignals() + if err != nil { + return err + } + path, err := os.Getwd() + if err != nil { + return err + } + server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) + if err != nil { + return fmt.Errorf("failed creating server: %w", err) + } + c, err := loadConfig(shimConfigFlag) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load shim config: %w", err) + } + sv, err := shim.NewService( + shim.Config{ + Path: path, + Namespace: namespaceFlag, + WorkDir: workdirFlag, + RuntimeRoot: runtimeRootFlag, + RunscConfig: c.RunscConfig, + }, + &remoteEventsPublisher{address: addressFlag}, + ) + if err != nil { + return err + } + registerShimService(server, sv) + if err := serve(server, socketFlag); err != nil { + return err + } + return handleSignals(signals, server, sv) +} + +// serve serves the ttrpc API over a unix socket at the provided path this +// function does not block. +func serve(server *ttrpc.Server, path string) error { + var ( + l net.Listener + err error + ) + if path == "" { + l, err = net.FileListener(os.NewFile(3, "socket")) + path = "[inherited from parent]" + } else { + if len(path) > 106 { + return fmt.Errorf("%q: unix socket path too long (> 106)", path) + } + l, err = net.Listen("unix", "\x00"+path) + } + if err != nil { + return err + } + go func() { + defer l.Close() + err := server.Serve(context.Background(), l) + if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + log.Fatalf("ttrpc server failure: %v", err) + } + }() + return nil +} + +// setupSignals creates a new signal handler for all signals and sets the shim +// as a sub-reaper so that the container processes are reparented. +func setupSignals() (chan os.Signal, error) { + signals := make(chan os.Signal, 32) + signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE) + // make sure runc is setup to use the monitor for waiting on processes. + // TODO(random-liu): Move shim/reaper.go to a separate package. + runsc.Monitor = reaper.Default + // Set the shim as the subreaper for all orphaned processes created by + // the container. + if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { + return nil, err + } + return signals, nil +} + +func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error { + var ( + termOnce sync.Once + done = make(chan struct{}) + ) + + for { + select { + case <-done: + return nil + case s := <-signals: + switch s { + case unix.SIGCHLD: + if _, err := sys.Reap(false); err != nil { + log.Printf("reap error: %v", err) + } + case unix.SIGTERM, unix.SIGINT: + go termOnce.Do(func() { + ctx := context.TODO() + if err := server.Shutdown(ctx); err != nil { + log.Printf("failed to shutdown server: %v", err) + } + // Ensure our child is dead if any. + sv.Kill(ctx, &KillRequest{ + Signal: uint32(syscall.SIGKILL), + All: true, + }) + sv.Delete(context.Background(), &types.Empty{}) + close(done) + }) + case unix.SIGPIPE: + } + } + } +} + +type remoteEventsPublisher struct { + address string +} + +func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { + ns, _ := namespaces.Namespace(ctx) + encoded, err := typeurl.MarshalAny(event) + if err != nil { + return err + } + data, err := encoded.Marshal() + if err != nil { + return err + } + cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns) + cmd.Stdin = bytes.NewReader(data) + c, err := reaper.Default.Start(cmd) + if err != nil { + return err + } + status, err := reaper.Default.Wait(cmd, c) + if err != nil { + return fmt.Errorf("failed to publish event: %w", err) + } + if status != 0 { + return fmt.Errorf("failed to publish event: status %d", status) + } + return nil +} diff --git a/shim/v2/BUILD b/shim/v2/BUILD new file mode 100644 index 000000000..8de9ac0ba --- /dev/null +++ b/shim/v2/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "containerd-shim-runsc-v1", + srcs = [ + "main.go", + ], + static = True, + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/shim/v2", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + ], +) diff --git a/shim/v2/main.go b/shim/v2/main.go new file mode 100644 index 000000000..753871eea --- /dev/null +++ b/shim/v2/main.go @@ -0,0 +1,26 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/containerd/containerd/runtime/v2/shim" + + "gvisor.dev/gvisor/pkg/shim/v2" +) + +func main() { + shim.Run("io.containerd.runsc.v1", v2.New) +} diff --git a/test/README.md b/test/README.md index 02bbf42ff..15b0f4c33 100644 --- a/test/README.md +++ b/test/README.md @@ -24,11 +24,11 @@ also used to run these tests in `kokoro`. To run image and integration tests, run: -`./scripts/docker_tests.sh` +`make docker-tests` To run root tests, run: -`./scripts/root_tests.sh` +`make root-tests` There are a few other interesting variations for image and integration tests: diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md new file mode 100644 index 000000000..d1bbabf6f --- /dev/null +++ b/test/benchmarks/README.md @@ -0,0 +1,157 @@ +# Benchmark tools + +This package and subpackages are for running macro benchmarks on `runsc`. They +are meant to replace the previous //benchmarks benchmark-tools written in +python. + +Benchmarks are meant to look like regular golang benchmarks using the testing.B +library. + +## Setup + +To run benchmarks you will need: + +* Docker installed (17.09.0 or greater). + +The easiest way to setup runsc for running benchmarks is to use the make file. +From the root directory: + +* Download images: `make load-all-images` +* Install runsc suitable for benchmarking, which should probably not have + strace or debug logs enabled. For example:`make configure RUNTIME=myrunsc + ARGS=--platform=kvm`. +* Restart docker: `sudo service docker restart` + +You should now have a runtime with the following options configured in +`/etc/docker/daemon.json` + +``` +"myrunsc": { + "path": "/tmp/myrunsc/runsc", + "runtimeArgs": [ + "--debug-log", + "/tmp/bench/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%", + "--platform=kvm" + ] + }, + +``` + +This runtime has been configured with a debugging off and strace logs off and is +using kvm for demonstration. + +## Running benchmarks + +Given the runtime above runtime `myrunsc`, run benchmarks with the following: + +``` +make sudo TARGETS=//path/to:target ARGS="--runtime=myrunsc -test.v \ + -test.bench=." OPTIONS="-c opt +``` + +For example, to run only the Iperf tests: + +``` +make sudo TARGETS=//test/benchmarks/network:network_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=Iperf" OPTIONS="-c opt" +``` + +Benchmarks are run with root as some benchmarks require root privileges to do +things like drop caches. + +## Writing benchmarks + +Benchmarks consist of docker images as Dockerfiles and golang testing.B +benchmarks. + +### Dockerfiles: + +* Are stored at //images. +* New Dockerfiles go in an appropriately named directory at + `//images/benchmarks/my-cool-dockerfile`. +* Dockerfiles for benchmarks should: + * Use explicitly versioned packages. + * Not use ENV and CMD statements...it is easy to add these in the API. +* Note: A common pattern for getting access to a tmpfs mount is to copy files + there after container start. See: //test/benchmarks/build/bazel_test.go. You + can also make your own with `RunOpts.Mounts`. + +### testing.B packages + +In general, benchmarks should look like this: + +```golang + +var h harness.Harness + +func BenchmarkMyCoolOne(b *testing.B) { + machine, err := h.GetMachine() + // check err + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + b.ResetTimer() + + //Respect b.N. + for i := 0; i < b.N; i++ { + out, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/my-cool-image", + Env: []string{"MY_VAR=awesome"}, + other options...see dockerutil + }, "sh", "-c", "echo MY_VAR") + //check err + b.StopTimer() + + // Do parsing and reporting outside of the timer. + number := parseMyMetric(out) + b.ReportMetric(number, "my-cool-custom-metric") + + b.StartTimer() + } +} + +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} +``` + +Some notes on the above: + +* The harness is initiated in the TestMain method and made global to test + module. The harness will handle any presetup that needs to happen with + flags, remote virtual machines (eventually), and other services. +* Respect `b.N` in that users of the benchmark may want to "run for an hour" + or something of the sort. +* Use the `b.ReportMetric()` method to report custom metrics. +* Set the timer if time is useful for reporting. There isn't a way to turn off + default metrics in testing.B (B/op, allocs/op, ns/op). +* Take a look at dockerutil at //pkg/test/dockerutil to see all methods + available from containers. The API is based on the "official" + [docker API for golang](https://pkg.go.dev/mod/github.com/docker/docker). +* `harness.GetMachine()` marks how many machines this tests needs. If you have + a client and server and to mark them as multiple machines, call + `harness.GetMachine()` twice. + +## Profiling + +For profiling, the runtime is required to have the `--profile` flag enabled. +This flag loosens seccomp filters so that the runtime can write profile data to +disk. This configuration is not recommended for production. + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile --platform=kvm --vfs2"`. The kvm and vfs2 flags are not + required, but are included for demonstration. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles fs_test test run: + +``` +make sudo TARGETS=//test/benchmarks/fs:fs_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD new file mode 100644 index 000000000..32c139204 --- /dev/null +++ b/test/benchmarks/base/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "base", + testonly = 1, + srcs = [ + "base.go", + ], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "base_test", + size = "large", + srcs = [ + "size_test.go", + "startup_test.go", + "sysbench_test.go", + ], + library = ":base", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/benchmarks/base/base.go b/test/benchmarks/base/base.go new file mode 100644 index 000000000..7bac52ff1 --- /dev/null +++ b/test/benchmarks/base/base.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package base holds base performance benchmarks. +package base + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var testHarness harness.Harness + +// TestMain is the main method for package network. +func TestMain(m *testing.M) { + testHarness.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go new file mode 100644 index 000000000..7d3877459 --- /dev/null +++ b/test/benchmarks/base/size_test.go @@ -0,0 +1,221 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkSizeEmpty creates N empty containers and reads memory usage from +// /proc/meminfo. +func BenchmarkSizeEmpty(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + meminfo := tools.Meminfo{} + ctx := context.Background() + containers := make([]*dockerutil.Container, 0, b.N) + + // DropCaches before the test. + harness.DropCaches(machine) + + // Check available memory on 'machine'. + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to get meminfo: %v", err) + } + + // Make N containers. + for i := 0; i < b.N; i++ { + container := machine.GetContainer(ctx, b) + containers = append(containers, container) + if err := container.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/alpine", + }, "sh", "-c", "echo Hello && sleep 1000"); err != nil { + cleanUpContainers(ctx, containers) + b.Fatalf("failed to run container: %v", err) + } + if _, err := container.WaitForOutputSubmatch(ctx, "Hello", 5*time.Second); err != nil { + cleanUpContainers(ctx, containers) + b.Fatalf("failed to read container output: %v", err) + } + } + + // Drop caches again before second measurement. + harness.DropCaches(machine) + + // Check available memory after containers are up. + after, err := machine.RunCommand(cmd, args...) + cleanUpContainers(ctx, containers) + if err != nil { + b.Fatalf("failed to get meminfo: %v", err) + } + meminfo.Report(b, before, after) +} + +// BenchmarkSizeNginx starts N containers running Nginx, checks that they're +// serving, and checks memory used based on /proc/meminfo. +func BenchmarkSizeNginx(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + // DropCaches for the first measurement. + harness.DropCaches(machine) + + // Measure MemAvailable before creating containers. + meminfo := tools.Meminfo{} + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + + // Make N Nginx containers. + ctx := context.Background() + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + } + const port = 80 + servers := startServers(ctx, b, + serverArgs{ + machine: machine, + port: port, + runOpts: runOpts, + cmd: []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"}, + }) + defer cleanUpContainers(ctx, servers) + + // DropCaches after servers are created. + harness.DropCaches(machine) + // Take after measurement. + after, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + meminfo.Report(b, before, after) +} + +// BenchmarkSizeNode starts N containers running a Node app, checks that +// they're serving, and checks memory used based on /proc/meminfo. +func BenchmarkSizeNode(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + // Make a redis instance for Node to connect. + ctx := context.Background() + redis, redisIP := redisInstance(ctx, b, machine) + defer redis.CleanUp(ctx) + + // DropCaches after redis is created. + harness.DropCaches(machine) + + // Take before measurement. + meminfo := tools.Meminfo{} + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo commend: %v", err) + } + + // Create N Node servers. + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + } + nodeCmd := []string{"node", "index.js", redisIP.String()} + const port = 8080 + servers := startServers(ctx, b, + serverArgs{ + machine: machine, + port: port, + runOpts: runOpts, + cmd: nodeCmd, + }) + defer cleanUpContainers(ctx, servers) + + // DropCaches after servers are created. + harness.DropCaches(machine) + // Take after measurement. + cmd, args = meminfo.MakeCmd() + after, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + meminfo.Report(b, before, after) +} + +// serverArgs wraps args for startServers and runServerWorkload. +type serverArgs struct { + machine harness.Machine + port int + runOpts dockerutil.RunOpts + cmd []string +} + +// startServers starts b.N containers defined by 'runOpts' and 'cmd' and uses +// 'machine' to check that each is up. +func startServers(ctx context.Context, b *testing.B, args serverArgs) []*dockerutil.Container { + b.Helper() + servers := make([]*dockerutil.Container, 0, b.N) + + // Create N servers and wait until each of them is serving. + for i := 0; i < b.N; i++ { + server := args.machine.GetContainer(ctx, b) + servers = append(servers, server) + if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to spawn node instance: %v", err) + } + + // Get the container IP. + servingIP, err := server.FindIP(ctx, false) + if err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to get ip from server: %v", err) + } + + // Wait until the server is up. + if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to wait for serving") + } + } + return servers +} + +// cleanUpContainers cleans up a slice of containers. +func cleanUpContainers(ctx context.Context, containers []*dockerutil.Container) { + for _, c := range containers { + if c != nil { + c.CleanUp(ctx) + } + } +} diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go new file mode 100644 index 000000000..c36a544db --- /dev/null +++ b/test/benchmarks/base/startup_test.go @@ -0,0 +1,155 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkStartEmpty times startup time for an empty container. +func BenchmarkStartupEmpty(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + for i := 0; i < b.N; i++ { + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/alpine", + }, "true"); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} + +// BenchmarkStartupNginx times startup for a Nginx instance. +// Time is measured from start until the first request is served. +func BenchmarkStartupNginx(b *testing.B) { + // The machine to hold Nginx and the Node Server. + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + } + runServerWorkload(ctx, b, + serverArgs{ + machine: machine, + runOpts: runOpts, + port: 80, + cmd: []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"}, + }) +} + +// BenchmarkStartupNode times startup for a Node application instance. +// Time is measured from start until the first request is served. +// Note that the Node app connects to a Redis instance before serving. +func BenchmarkStartupNode(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + redis, redisIP := redisInstance(ctx, b, machine) + defer redis.CleanUp(ctx) + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + } + + cmd := []string{"node", "index.js", redisIP.String()} + runServerWorkload(ctx, b, + serverArgs{ + machine: machine, + port: 8080, + runOpts: runOpts, + cmd: cmd, + }) +} + +// redisInstance returns a Redis container and its reachable IP. +func redisInstance(ctx context.Context, b *testing.B, machine harness.Machine) (*dockerutil.Container, net.IP) { + b.Helper() + // Spawn a redis instance for the app to use. + redis := machine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to spwan redis instance: %v", err) + } + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to get IP from redis instance: %v", err) + } + return redis, redisIP +} + +// runServerWorkload runs a server workload defined by 'runOpts' and 'cmd'. +// 'clientMachine' is used to connect to the server on 'serverMachine'. +func runServerWorkload(ctx context.Context, b *testing.B, args serverArgs) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := func() error { + server := args.machine.GetContainer(ctx, b) + defer func() { + b.StopTimer() + // Cleanup servers as we run so that we can go indefinitely. + server.CleanUp(ctx) + b.StartTimer() + }() + if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil { + return fmt.Errorf("failed to spawn node instance: %v", err) + } + + servingIP, err := server.FindIP(ctx, false) + if err != nil { + return fmt.Errorf("failed to get ip from server: %v", err) + } + + // Wait until the Client sees the server as up. + if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil { + return fmt.Errorf("failed to wait for serving: %v", err) + } + return nil + }(); err != nil { + b.Fatal(err) + } + } +} diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go new file mode 100644 index 000000000..6fb813640 --- /dev/null +++ b/test/benchmarks/base/sysbench_test.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +type testCase struct { + name string + test tools.Sysbench +} + +// BenchmarSysbench runs sysbench on the runtime. +func BenchmarkSysbench(b *testing.B) { + + testCases := []testCase{ + testCase{ + name: "CPU", + test: &tools.SysbenchCPU{ + Base: tools.SysbenchBase{ + Threads: 1, + Time: 5, + }, + MaxPrime: 50000, + }, + }, + testCase{ + name: "Memory", + test: &tools.SysbenchMemory{ + Base: tools.SysbenchBase{ + Threads: 1, + }, + BlockSize: "1M", + TotalSize: "500G", + }, + }, + testCase{ + name: "Mutex", + test: &tools.SysbenchMutex{ + Base: tools.SysbenchBase{ + Threads: 8, + }, + Loops: 1, + Locks: 10000000, + Num: 4, + }, + }, + } + + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + + ctx := context.Background() + sysbench := machine.GetContainer(ctx, b) + defer sysbench.CleanUp(ctx) + + out, err := sysbench.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/sysbench", + }, tc.test.MakeCmd()...) + if err != nil { + b.Fatalf("failed to run sysbench: %v: logs:%s", err, out) + } + tc.test.Report(b, out) + }) + } +} diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD new file mode 100644 index 000000000..93b380e8a --- /dev/null +++ b/test/benchmarks/database/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "database", + testonly = 1, + srcs = ["database.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "database_test", + size = "enormous", + srcs = ["redis_test.go"], + library = ":database", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go new file mode 100644 index 000000000..9eeb59f9a --- /dev/null +++ b/test/benchmarks/database/database.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package database holds benchmarks around database applications. +package database + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package database. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go new file mode 100644 index 000000000..6671a4969 --- /dev/null +++ b/test/benchmarks/database/redis_test.go @@ -0,0 +1,123 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "context" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// All possible operations from redis. Note: "ping" will +// run both PING_INLINE and PING_BUILD. +var operations []string = []string{ + "PING_INLINE", + "PING_BULK", + "SET", + "GET", + "INCR", + "LPUSH", + "RPUSH", + "LPOP", + "RPOP", + "SADD", + "HSET", + "SPOP", + "LRANGE_100", + "LRANGE_300", + "LRANGE_500", + "LRANGE_600", + "MSET", +} + +// BenchmarkRedis runs redis-benchmark against a redis instance and reports +// data in queries per second. Each is reported by named operation (e.g. LPUSH). +func BenchmarkRedis(b *testing.B) { + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Redis runs on port 6379 by default. + port := 6379 + ctx := context.Background() + + for _, operation := range operations { + b.Run(operation, func(b *testing.B) { + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + // The redis docker container takes no arguments to run a redis server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + Ports: []int{port}, + }); err != nil { + b.Fatalf("failed to start redis server with: %v", err) + } + + if out, err := server.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get IP from server: %v", err) + } + + serverPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to get IP from server: %v", err) + } + + if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil { + b.Fatalf("failed to start redis with: %v", err) + } + + redis := tools.Redis{ + Operation: operation, + } + + // Reset profiles and timer to begin the measurement. + server.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }, redis.MakeCmd(ip, serverPort)...) + if err != nil { + b.Fatalf("redis-benchmark failed with: %v", err) + } + + // Stop time while we parse results. + b.StopTimer() + redis.Report(b, out) + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD new file mode 100644 index 000000000..45f11372b --- /dev/null +++ b/test/benchmarks/fs/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "fs", + testonly = 1, + srcs = ["fs.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "fs_test", + size = "large", + srcs = [ + "bazel_test.go", + "fio_test.go", + ], + library = ":fs", + tags = [ + # Requires docker and runsc to be configured before test runs. + "local", + "manual", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + "@com_github_docker_docker//api/types/mount:go_default_library", + ], +) diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go new file mode 100644 index 000000000..ef1b8e4ea --- /dev/null +++ b/test/benchmarks/fs/bazel_test.go @@ -0,0 +1,120 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package fs + +import ( + "context" + "fmt" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// Note: CleanCache versions of this test require running with root permissions. +func BenchmarkBuildABSL(b *testing.B) { + runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...") +} + +// Note: CleanCache versions of this test require running with root permissions. +// Note: This test takes on the order of 10m per permutation for runsc on kvm. +func BenchmarkBuildRunsc(b *testing.B) { + runBuildBenchmark(b, "benchmarks/runsc", "/gvisor", "runsc:runsc") +} + +func runBuildBenchmark(b *testing.B, image, workdir, target string) { + b.Helper() + // Get a machine from the Harness on which to run. + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + // Dimensions here are clean/dirty cache (do or don't drop caches) + // and if the mount on which we are compiling is a tmpfs/bind mount. + benchmarks := []struct { + name string + clearCache bool // clearCache drops caches before running. + tmpfs bool // tmpfs will run compilation on a tmpfs. + }{ + {name: "CleanCache", clearCache: true, tmpfs: false}, + {name: "DirtyCache", clearCache: false, tmpfs: false}, + {name: "CleanCacheTmpfs", clearCache: true, tmpfs: true}, + {name: "DirtyCacheTmpfs", clearCache: false, tmpfs: true}, + } + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + // Grab a container. + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + // Start a container and sleep. + if err := container.Spawn(ctx, dockerutil.RunOpts{ + Image: image, + }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { + b.Fatalf("run failed with: %v", err) + } + + // If we are running on a tmpfs, copy to /tmp which is a tmpfs. + prefix := "" + if bm.tmpfs { + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + "cp", "-r", workdir, "/tmp/."); err != nil { + b.Fatalf("failed to copy directory: %v (%s)", err, out) + } + prefix = "/tmp" + } + + // Restart profiles after the copy. + container.RestartProfiles() + b.ResetTimer() + // Drop Caches and bazel clean should happen inside the loop as we may use + // time options with b.N. (e.g. Run for an hour.) + for i := 0; i < b.N; i++ { + b.StopTimer() + // Drop Caches for clear cache runs. + if bm.clearCache { + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + } + b.StartTimer() + + got, err := container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: prefix + workdir, + }, "bazel", "build", "-c", "opt", target) + if err != nil { + b.Fatalf("build failed with: %v", err) + } + b.StopTimer() + + want := "Build completed successfully" + if !strings.Contains(got, want) { + b.Fatalf("string %s not in: %s", want, got) + } + // Clean bazel in case we use b.N. + _, err = container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: prefix + workdir, + }, "bazel", "clean") + if err != nil { + b.Fatalf("build failed with: %v", err) + } + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go new file mode 100644 index 000000000..65874ed8b --- /dev/null +++ b/test/benchmarks/fs/fio_test.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package fs + +import ( + "context" + "fmt" + "path/filepath" + "strings" + "testing" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkFio runs fio on the runtime under test. There are 4 basic test +// cases each run on a tmpfs mount and a bind mount. Fio requires root so that +// caches can be dropped. +func BenchmarkFio(b *testing.B) { + testCases := []tools.Fio{ + tools.Fio{ + Test: "write", + Size: "5G", + Blocksize: "1M", + Iodepth: 4, + }, + tools.Fio{ + Test: "read", + Size: "5G", + Blocksize: "1M", + Iodepth: 4, + }, + tools.Fio{ + Test: "randwrite", + Size: "5G", + Blocksize: "4K", + Iodepth: 4, + Time: 30, + }, + tools.Fio{ + Test: "randread", + Size: "5G", + Blocksize: "4K", + Iodepth: 4, + Time: 30, + }, + } + + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} { + for _, tc := range testCases { + testName := strings.Title(tc.Test) + strings.Title(string(fsType)) + b.Run(testName, func(b *testing.B) { + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + // Directory and filename inside container where fio will read/write. + outdir := "/data" + outfile := filepath.Join(outdir, "test.txt") + + // Make the required mount and grab a cleanup for bind mounts + // as they are backed by a temp directory (mktemp). + mnt, mountCleanup, err := makeMount(machine, fsType, outdir) + if err != nil { + b.Fatalf("failed to make mount: %v", err) + } + defer mountCleanup() + + // Start the container with the mount. + if err := container.Spawn( + ctx, + dockerutil.RunOpts{ + Image: "benchmarks/fio", + Mounts: []mount.Mount{ + mnt, + }, + }, + // Sleep on the order of b.N. + "sleep", fmt.Sprintf("%d", 1000*b.N), + ); err != nil { + b.Fatalf("failed to start fio container with: %v", err) + } + + // For reads, we need a file to read so make one inside the container. + if strings.Contains(tc.Test, "read") { + fallocateCmd := fmt.Sprintf("fallocate -l %s %s", tc.Size, outfile) + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + strings.Split(fallocateCmd, " ")...); err != nil { + b.Fatalf("failed to create readable file on mount: %v, %s", err, out) + } + } + + // Drop caches just before running. + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches with %v. You probably need root.", err) + } + cmd := tc.MakeCmd(outfile) + container.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Run fio. + data, err := container.Exec(ctx, dockerutil.ExecOpts{}, cmd...) + if err != nil { + b.Fatalf("failed to run cmd %v: %v", cmd, err) + } + b.StopTimer() + tc.Report(b, data) + // If b.N is used (i.e. we run for an hour), we should drop caches + // after each run. + if err := harness.DropCaches(machine); err != nil { + b.Fatalf("failed to drop caches: %v", err) + } + b.StartTimer() + } + }) + } + } +} + +// makeMount makes a mount and cleanup based on the requested type. Bind +// and volume mounts are backed by a temp directory made with mktemp. +// tmpfs mounts require no such backing and are just made. +// It is up to the caller to call the returned cleanup. +func makeMount(machine harness.Machine, mountType mount.Type, target string) (mount.Mount, func(), error) { + switch mountType { + case mount.TypeVolume, mount.TypeBind: + dir, err := machine.RunCommand("mktemp", "-d") + if err != nil { + return mount.Mount{}, func() {}, fmt.Errorf("failed to create tempdir: %v", err) + } + dir = strings.TrimSuffix(dir, "\n") + + out, err := machine.RunCommand("chmod", "777", dir) + if err != nil { + machine.RunCommand("rm", "-rf", dir) + return mount.Mount{}, func() {}, fmt.Errorf("failed modify directory: %v %s", err, out) + } + return mount.Mount{ + Target: target, + Source: dir, + Type: mount.TypeBind, + }, func() { machine.RunCommand("rm", "-rf", dir) }, nil + case mount.TypeTmpfs: + return mount.Mount{ + Target: target, + Type: mount.TypeTmpfs, + }, func() {}, nil + default: + return mount.Mount{}, func() {}, fmt.Errorf("illegal mount time not supported: %v", mountType) + } +} diff --git a/test/benchmarks/fs/fs.go b/test/benchmarks/fs/fs.go new file mode 100644 index 000000000..e5ca28c3b --- /dev/null +++ b/test/benchmarks/fs/fs.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fs holds benchmarks around filesystem performance. +package fs + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package fs. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD new file mode 100644 index 000000000..c2e316709 --- /dev/null +++ b/test/benchmarks/harness/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "harness", + testonly = 1, + srcs = [ + "harness.go", + "machine.go", + "util.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) diff --git a/test/benchmarks/harness/harness.go b/test/benchmarks/harness/harness.go new file mode 100644 index 000000000..68bd7b4cf --- /dev/null +++ b/test/benchmarks/harness/harness.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package harness holds utility code for running benchmarks on Docker. +package harness + +import ( + "flag" + + "gvisor.dev/gvisor/pkg/test/dockerutil" +) + +// Harness is a handle for managing state in benchmark runs. +type Harness struct { +} + +// Init performs any harness initilialization before runs. +func (h *Harness) Init() error { + flag.Parse() + dockerutil.EnsureSupportedDockerVersion() + return nil +} + +// GetMachine returns this run's implementation of machine. +func (h *Harness) GetMachine() (Machine, error) { + return &localMachine{}, nil +} diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go new file mode 100644 index 000000000..88e5e841b --- /dev/null +++ b/test/benchmarks/harness/machine.go @@ -0,0 +1,81 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package harness + +import ( + "context" + "net" + "os/exec" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Machine describes a real machine for use in benchmarks. +type Machine interface { + // GetContainer gets a container from the machine. The container uses the + // runtime under test and is profiled if requested by flags. + GetContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + + // GetNativeContainer gets a native container from the machine. Native containers + // use runc by default and are not profiled. + GetNativeContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + + // RunCommand runs cmd on this machine. + RunCommand(cmd string, args ...string) (string, error) + + // Returns IP Address for the machine. + IPAddress() (net.IP, error) + + // CleanUp cleans up this machine. + CleanUp() +} + +// localMachine describes this machine. +type localMachine struct { +} + +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeContainer(ctx, logger) +} + +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetNativeContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeNativeContainer(ctx, logger) +} + +// RunCommand implements Machine.RunCommand for localMachine. +func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) { + c := exec.Command(cmd, args...) + out, err := c.CombinedOutput() + return string(out), err +} + +// IPAddress implements Machine.IPAddress. +func (l *localMachine) IPAddress() (net.IP, error) { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + return nil, err + } + defer conn.Close() + + addr := conn.LocalAddr().(*net.UDPAddr) + return addr.IP, nil +} + +// CleanUp implements Machine.CleanUp and does nothing for localMachine. +func (*localMachine) CleanUp() { +} diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go new file mode 100644 index 000000000..86b863f78 --- /dev/null +++ b/test/benchmarks/harness/util.go @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package harness + +import ( + "context" + "fmt" + "net" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +//TODO(gvisor.dev/issue/3535): move to own package or move methods to harness struct. + +// WaitUntilServing grabs a container from `machine` and waits for a server at +// IP:port. +func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error { + var logger testutil.DefaultLogger = "util" + netcat := machine.GetNativeContainer(ctx, logger) + defer netcat.CleanUp(ctx) + + cmd := fmt.Sprintf("while ! wget -q --spider http://%s:%d; do true; done", server, port) + _, err := netcat.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/util", + }, "sh", "-c", cmd) + return err +} + +// DropCaches drops caches on the provided machine. Requires root. +func DropCaches(machine Machine) error { + if out, err := machine.RunCommand("/bin/sh", "-c", "sync && sysctl vm.drop_caches=3"); err != nil { + return fmt.Errorf("failed to drop caches: %v logs: %s", err, out) + } + return nil +} diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD new file mode 100644 index 000000000..bb242d385 --- /dev/null +++ b/test/benchmarks/media/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "media", + testonly = 1, + srcs = ["media.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "media_test", + size = "large", + srcs = ["ffmpeg_test.go"], + library = ":media", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go new file mode 100644 index 000000000..7822dfad7 --- /dev/null +++ b/test/benchmarks/media/ffmpeg_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package media + +import ( + "context" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkFfmpeg runs ffmpeg in a container and records runtime. +// BenchmarkFfmpeg should run as root to drop caches. +func BenchmarkFfmpeg(b *testing.B) { + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ffmpeg", + }, cmd...); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go new file mode 100644 index 000000000..c7b35b758 --- /dev/null +++ b/test/benchmarks/media/media.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package media holds benchmarks around media processing applications. +package media + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package media. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD new file mode 100644 index 000000000..970f52706 --- /dev/null +++ b/test/benchmarks/ml/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ml", + testonly = 1, + srcs = ["ml.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "ml_test", + size = "large", + srcs = ["tensorflow_test.go"], + library = ":ml", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go new file mode 100644 index 000000000..13282d7bb --- /dev/null +++ b/test/benchmarks/ml/ml.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ml holds benchmarks around machine learning performance. +package ml + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package ml. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go new file mode 100644 index 000000000..f7746897d --- /dev/null +++ b/test/benchmarks/ml/tensorflow_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package ml + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkTensorflow runs workloads from a TensorFlow tutorial. +// See: https://github.com/aymericdamien/TensorFlow-Examples +func BenchmarkTensorflow(b *testing.B) { + workloads := map[string]string{ + "GradientDecisionTree": "2_BasicModels/gradient_boosted_decision_tree.py", + "Kmeans": "2_BasicModels/kmeans.py", + "LogisticRegression": "2_BasicModels/logistic_regression.py", + "NearestNeighbor": "2_BasicModels/nearest_neighbor.py", + "RandomForest": "2_BasicModels/random_forest.py", + "ConvolutionalNetwork": "3_NeuralNetworks/convolutional_network.py", + "MultilayerPerceptron": "3_NeuralNetworks/multilayer_perceptron.py", + "NeuralNetwork": "3_NeuralNetworks/neural_network.py", + } + + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + for name, workload := range workloads { + b.Run(name, func(b *testing.B) { + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if out, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/tensorflow", + Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"}, + WorkDir: "/TensorFlow-Examples/examples", + }, "python", workload); err != nil { + b.Fatalf("failed to run container: %v logs: %s", err, out) + } + } + }) + } + +} diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD new file mode 100644 index 000000000..472b5c387 --- /dev/null +++ b/test/benchmarks/network/BUILD @@ -0,0 +1,42 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "network", + testonly = 1, + srcs = [ + "network.go", + "static_server.go", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) + +go_test( + name = "network_test", + size = "large", + srcs = [ + "httpd_test.go", + "iperf_test.go", + "nginx_test.go", + "node_test.go", + "ruby_test.go", + ], + library = ":network", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go new file mode 100644 index 000000000..369ab326e --- /dev/null +++ b/test/benchmarks/network/httpd_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// see Dockerfile '//images/benchmarks/httpd'. +var httpdDocs = map[string]string{ + "notfound": "notfound", + "1Kb": "latin1k.txt", + "10Kb": "latin10k.txt", + "100Kb": "latin100k.txt", + "1Mb": "latin1024k.txt", + "10Mb": "latin10240k.txt", +} + +// BenchmarkHttpdConcurrency iterates the concurrency argument and tests +// how well the runtime under test handles requests in parallel. +func BenchmarkHttpdConcurrency(b *testing.B) { + // The test iterates over client concurrency, so set other parameters. + concurrency := []int{1, 25, 50, 100, 1000} + + for _, c := range concurrency { + b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: httpdDocs["10Kb"], + } + runHttpd(b, hey, false /* reverse */) + }) + } +} + +// BenchmarkHttpdDocSize iterates over different sized payloads, testing how +// well the runtime handles sending different payload sizes. +func BenchmarkHttpdDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, false /* reverse */) +} + +// BenchmarkReverseHttpdDocSize iterates over different sized payloads, testing +// how well the runtime handles receiving different payload sizes. +func BenchmarkReverseHttpdDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, true /* reverse */) +} + +// benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks +// for each size. +func benchmarkHttpdDocSize(b *testing.B, reverse bool) { + b.Helper() + for name, filename := range httpdDocs { + concurrency := []int{1, 25, 50, 100, 1000} + for _, c := range concurrency { + b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: filename, + } + runHttpd(b, hey, reverse) + }) + } + } +} + +// runHttpd configures the static serving methods to run httpd. +func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) { + // httpd runs on port 80. + port := 80 + httpdRunOpts := dockerutil.RunOpts{ + Image: "benchmarks/httpd", + Ports: []int{port}, + Env: []string{ + // Standard environmental variables for httpd. + "APACHE_RUN_DIR=/tmp", + "APACHE_RUN_USER=nobody", + "APACHE_RUN_GROUP=nogroup", + "APACHE_LOG_DIR=/tmp", + "APACHE_PID_FILE=/tmp/apache.pid", + }, + } + httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"} + runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse) +} diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go new file mode 100644 index 000000000..b8ab7dfb8 --- /dev/null +++ b/test/benchmarks/network/iperf_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +func BenchmarkIperf(b *testing.B) { + iperf := tools.Iperf{ + Time: 10, // time in seconds to run client. + } + + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + ctx := context.Background() + for _, bm := range []struct { + name string + clientFunc func(context.Context, testutil.Logger) *dockerutil.Container + serverFunc func(context.Context, testutil.Logger) *dockerutil.Container + }{ + // We are either measuring the server or the client. The other should be + // runc. e.g. Upload sees how fast the runtime under test uploads to a native + // server. + { + name: "Upload", + clientFunc: clientMachine.GetContainer, + serverFunc: serverMachine.GetNativeContainer, + }, + { + name: "Download", + clientFunc: clientMachine.GetNativeContainer, + serverFunc: serverMachine.GetContainer, + }, + } { + b.Run(bm.name, func(b *testing.B) { + // Set up the containers. + server := bm.serverFunc(ctx, b) + defer server.CleanUp(ctx) + client := bm.clientFunc(ctx, b) + defer client.CleanUp(ctx) + + // iperf serves on port 5001 by default. + port := 5001 + + // Start the server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + Ports: []int{port}, + }, "iperf", "-s"); err != nil { + b.Fatalf("failed to start server with: %v", err) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find port %d: %v", port, err) + } + + // Make sure the server is up and serving before we run. + if err := harness.WaitUntilServing(ctx, clientMachine, ip, servingPort); err != nil { + b.Fatalf("failed to wait for server: %v", err) + } + // Run the client. + b.ResetTimer() + + // Restart the server profiles. If the server isn't being profiled + // this does nothing. + server.RestartProfiles() + for i := 0; i < b.N; i++ { + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + }, iperf.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("failed to run client: %v", err) + } + b.StopTimer() + iperf.Report(b, out) + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go new file mode 100644 index 000000000..ce17ddb94 --- /dev/null +++ b/test/benchmarks/network/network.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package network holds benchmarks around raw network performance. +package network + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package network. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go new file mode 100644 index 000000000..9ec70369b --- /dev/null +++ b/test/benchmarks/network/nginx_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// see Dockerfile '//images/benchmarks/nginx'. +var nginxDocs = map[string]string{ + "notfound": "notfound", + "1Kb": "latin1k.txt", + "10Kb": "latin10k.txt", + "100Kb": "latin100k.txt", + "1Mb": "latin1024k.txt", + "10Mb": "latin10240k.txt", +} + +// BenchmarkNginxConcurrency iterates the concurrency argument and tests +// how well the runtime under test handles requests in parallel. +func BenchmarkNginxConcurrency(b *testing.B) { + concurrency := []int{1, 25, 100, 1000} + for _, c := range concurrency { + for _, tmpfs := range []bool{true, false} { + fs := "Gofer" + if tmpfs { + fs = "Tmpfs" + } + name := fmt.Sprintf("%d_%s", c, fs) + b.Run(name, func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: nginxDocs["10kb"], // see Dockerfile '//images/benchmarks/nginx' and httpd_test. + } + runNginx(b, hey, false /* reverse */, tmpfs /* tmpfs */) + }) + } + + } +} + +// BenchmarkNginxDocSize iterates over different sized payloads, testing how +// well the runtime handles sending different payload sizes. +func BenchmarkNginxDocSize(b *testing.B) { + benchmarkNginxDocSize(b, false /* reverse */, true /* tmpfs */) + benchmarkNginxDocSize(b, false /* reverse */, false /* tmpfs */) +} + +// BenchmarkReverseNginxDocSize iterates over different sized payloads, testing +// how well the runtime handles receiving different payload sizes. +func BenchmarkReverseNginxDocSize(b *testing.B) { + benchmarkNginxDocSize(b, true /* reverse */, true /* tmpfs */) +} + +// benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks +// for each size. +func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) { + for name, filename := range nginxDocs { + concurrency := []int{1, 25, 50, 100, 1000} + for _, c := range concurrency { + fs := "Gofer" + if tmpfs { + fs = "Tmpfs" + } + benchName := fmt.Sprintf("%s_%d_%s", name, c, fs) + b.Run(benchName, func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: filename, + } + runNginx(b, hey, reverse, tmpfs) + }) + } + } +} + +// runNginx configures the static serving methods to run httpd. +func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) { + // nginx runs on port 80. + port := 80 + nginxRunOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + Ports: []int{port}, + } + + nginxCmd := []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"} + if tmpfs { + nginxCmd = []string{"sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && nginx -c /etc/nginx/nginx.conf"} + } + + // Command copies nginxDocs to tmpfs serving directory and runs nginx. + runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse) +} diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go new file mode 100644 index 000000000..0f4a205b6 --- /dev/null +++ b/test/benchmarks/network/node_test.go @@ -0,0 +1,127 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkNode runs requests using 'hey' against a Node server run on +// 'runtime'. The server responds to requests by grabbing some data in a +// redis instance and returns the data in its reponse. The test loops through +// increasing amounts of concurency for requests. +func BenchmarkNode(b *testing.B) { + concurrency := []int{1, 5, 10, 25} + for _, c := range concurrency { + b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N * c, // Requests b.N requests per thread. + Concurrency: c, + } + runNode(b, hey) + }) + } +} + +// runNode runs the test for a given # of requests and concurrency. +func runNode(b *testing.B, hey *tools.Hey) { + b.Helper() + + // The machine to hold Redis and the Node Server. + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer serverMachine.CleanUp() + + // The machine to run 'hey'. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer clientMachine.CleanUp() + + ctx := context.Background() + + // Spawn a redis instance for the app to use. + redis := serverMachine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + b.Fatalf("failed to spwan redis instance: %v", err) + } + defer redis.CleanUp(ctx) + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + b.Fatalf("failed to get IP from redis instance: %v", err) + } + + // Node runs on port 8080. + port := 8080 + + // Start-up the Node server. + nodeApp := serverMachine.GetContainer(ctx, b) + if err := nodeApp.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + Ports: []int{port}, + }, "node", "index.js", redisIP.String()); err != nil { + b.Fatalf("failed to spawn node instance: %v", err) + } + defer nodeApp.CleanUp(ctx) + + servingIP, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get ip from server: %v", err) + } + + servingPort, err := nodeApp.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to port from node instance: %v", err) + } + + // Wait until the Client sees the server as up. + harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort) + + heyCmd := hey.MakeCmd(servingIP, servingPort) + + nodeApp.RestartProfiles() + b.ResetTimer() + + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) + } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go new file mode 100644 index 000000000..67f63f76a --- /dev/null +++ b/test/benchmarks/network/ruby_test.go @@ -0,0 +1,134 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkRuby runs requests using 'hey' against a ruby application server. +// On start, ruby app generates some random data and pushes it to a redis +// instance. On a request, the app grabs for random entries from the redis +// server, publishes it to a document, and returns the doc to the request. +func BenchmarkRuby(b *testing.B) { + concurrency := []int{1, 5, 10, 25} + for _, c := range concurrency { + b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N * c, // b.N requests per thread. + Concurrency: c, + } + runRuby(b, hey) + }) + } +} + +// runRuby runs the test for a given # of requests and concurrency. +func runRuby(b *testing.B, hey *tools.Hey) { + b.Helper() + // The machine to hold Redis and the Ruby Server. + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer serverMachine.CleanUp() + + // The machine to run 'hey'. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer clientMachine.CleanUp() + ctx := context.Background() + + // Spawn a redis instance for the app to use. + redis := serverMachine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + b.Fatalf("failed to spwan redis instance: %v", err) + } + defer redis.CleanUp(ctx) + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + b.Fatalf("failed to get IP from redis instance: %v", err) + } + + // Ruby runs on port 9292. + const port = 9292 + + // Start-up the Ruby server. + rubyApp := serverMachine.GetContainer(ctx, b) + if err := rubyApp.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ruby", + WorkDir: "/app", + Links: []string{redis.MakeLink("redis")}, + Ports: []int{port}, + Env: []string{ + fmt.Sprintf("PORT=%d", port), + "WEB_CONCURRENCY=20", + "WEB_MAX_THREADS=20", + "RACK_ENV=production", + fmt.Sprintf("HOST=%s", redisIP), + }, + User: "nobody", + }, "sh", "-c", "/usr/bin/puma"); err != nil { + b.Fatalf("failed to spawn node instance: %v", err) + } + defer rubyApp.CleanUp(ctx) + + servingIP, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get ip from server: %v", err) + } + + servingPort, err := rubyApp.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to port from node instance: %v", err) + } + + // Wait until the Client sees the server as up. + if err := harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort); err != nil { + b.Fatalf("failed to wait until serving: %v", err) + } + heyCmd := hey.MakeCmd(servingIP, servingPort) + rubyApp.RestartProfiles() + b.ResetTimer() + + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) + } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go new file mode 100644 index 000000000..e747a1395 --- /dev/null +++ b/test/benchmarks/network/static_server.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// runStaticServer runs static serving workloads (httpd, nginx). +func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) { + ctx := context.Background() + + // Get two machines: a client and server. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Make the containers. 'reverse=true' specifies that the client should use the + // runtime under test. + var client, server *dockerutil.Container + if reverse { + client = clientMachine.GetContainer(ctx, b) + server = serverMachine.GetNativeContainer(ctx, b) + } else { + client = clientMachine.GetNativeContainer(ctx, b) + server = serverMachine.GetContainer(ctx, b) + } + defer client.CleanUp(ctx) + defer server.CleanUp(ctx) + + // Start the server. + if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil { + b.Fatalf("failed to start server: %v", err) + } + + // Get its IP. + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + // Get the published port. + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Make sure the server is serving. + harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + b.ResetTimer() + server.RestartProfiles() + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, hey.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/benchmarks/tcp/BUILD b/test/benchmarks/tcp/BUILD index 6dde7d9e6..6dde7d9e6 100644 --- a/benchmarks/tcp/BUILD +++ b/test/benchmarks/tcp/BUILD diff --git a/benchmarks/tcp/README.md b/test/benchmarks/tcp/README.md index 38e6e69f0..38e6e69f0 100644 --- a/benchmarks/tcp/README.md +++ b/test/benchmarks/tcp/README.md diff --git a/benchmarks/tcp/nsjoin.c b/test/benchmarks/tcp/nsjoin.c index 524b4d549..524b4d549 100644 --- a/benchmarks/tcp/nsjoin.c +++ b/test/benchmarks/tcp/nsjoin.c diff --git a/benchmarks/tcp/tcp_benchmark.sh b/test/benchmarks/tcp/tcp_benchmark.sh index ef04b4ace..ef04b4ace 100755 --- a/benchmarks/tcp/tcp_benchmark.sh +++ b/test/benchmarks/tcp/tcp_benchmark.sh diff --git a/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go index 4b7ca7a14..5afe10f69 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -174,8 +174,8 @@ func newNetstackImpl(mode string) (impl, error) { } // Create a new network stack. - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol} s := stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, @@ -228,19 +228,26 @@ func newNetstackImpl(mode string) (impl, error) { }) // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) + { + opt := tcpip.TCPSACKEnabled(*sack) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Enable Receive Buffer Auto-Tuning. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(*moderateRecvBuf) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Set Congestion Control to cubic if requested. if *cubic { - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err) + opt := tcpip.CongestionControlOption("cubic") + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/test/benchmarks/tools/BUILD b/test/benchmarks/tools/BUILD new file mode 100644 index 000000000..e5734d85c --- /dev/null +++ b/test/benchmarks/tools/BUILD @@ -0,0 +1,33 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "tools", + srcs = [ + "ab.go", + "fio.go", + "hey.go", + "iperf.go", + "meminfo.go", + "redis.go", + "sysbench.go", + "tools.go", + ], + visibility = ["//:sandbox"], +) + +go_test( + name = "tools_test", + size = "small", + srcs = [ + "ab_test.go", + "fio_test.go", + "hey_test.go", + "iperf_test.go", + "meminfo_test.go", + "redis_test.go", + "sysbench_test.go", + ], + library = ":tools", +) diff --git a/test/benchmarks/tools/ab.go b/test/benchmarks/tools/ab.go new file mode 100644 index 000000000..4cc9c3bce --- /dev/null +++ b/test/benchmarks/tools/ab.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "testing" +) + +// ApacheBench is for the client application ApacheBench. +type ApacheBench struct { + Requests int + Concurrency int + Doc string + // TODO(zkoopmans): support KeepAlive and pass option to enable. +} + +// MakeCmd makes an ApacheBench command. +func (a *ApacheBench) MakeCmd(ip net.IP, port int) []string { + path := fmt.Sprintf("http://%s:%d/%s", ip, port, a.Doc) + // See apachebench (ab) for flags. + cmd := fmt.Sprintf("ab -n %d -c %d %s", a.Requests, a.Concurrency, path) + return []string{"sh", "-c", cmd} +} + +// Report parses and reports metrics from ApacheBench output. +func (a *ApacheBench) Report(b *testing.B, output string) { + // Parse and report custom metrics. + transferRate, err := a.parseTransferRate(output) + if err != nil { + b.Logf("failed to parse transferrate: %v", err) + } + b.ReportMetric(transferRate*1024, "transfer_rate_b/s") // Convert from Kb/s to b/s. + + latency, err := a.parseLatency(output) + if err != nil { + b.Logf("failed to parse latency: %v", err) + } + b.ReportMetric(latency/1000, "mean_latency_secs") // Convert from ms to s. + + reqPerSecond, err := a.parseRequestsPerSecond(output) + if err != nil { + b.Logf("failed to parse requests per second: %v", err) + } + b.ReportMetric(reqPerSecond, "requests_per_second") +} + +var transferRateRE = regexp.MustCompile(`Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received`) + +// parseTransferRate parses transfer rate from ApacheBench output. +func (a *ApacheBench) parseTransferRate(data string) (float64, error) { + match := transferRateRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var latencyRE = regexp.MustCompile(`Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s`) + +// parseLatency parses latency from ApacheBench output. +func (a *ApacheBench) parseLatency(data string) (float64, error) { + match := latencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var requestsPerSecondRE = regexp.MustCompile(`Requests per second:\s+(\d+\.?\d+?)\s+`) + +// parseRequestsPerSecond parses requests per second from ApacheBench output. +func (a *ApacheBench) parseRequestsPerSecond(data string) (float64, error) { + match := requestsPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/ab_test.go b/test/benchmarks/tools/ab_test.go new file mode 100644 index 000000000..28ee66ec1 --- /dev/null +++ b/test/benchmarks/tools/ab_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import "testing" + +// TestApacheBench checks the ApacheBench parsers on sample output. +func TestApacheBench(t *testing.T) { + // Sample output from apachebench. + sampleData := `This is ApacheBench, Version 2.3 <$Revision: 1826891 $> +Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ +Licensed to The Apache Software Foundation, http://www.apache.org/ + +Benchmarking 10.10.10.10 (be patient).....done + + +Server Software: Apache/2.4.38 +Server Hostname: 10.10.10.10 +Server Port: 80 + +Document Path: /latin10k.txt +Document Length: 210 bytes + +Concurrency Level: 1 +Time taken for tests: 0.180 seconds +Complete requests: 100 +Failed requests: 0 +Non-2xx responses: 100 +Total transferred: 38800 bytes +HTML transferred: 21000 bytes +Requests per second: 556.44 [#/sec] (mean) +Time per request: 1.797 [ms] (mean) +Time per request: 1.797 [ms] (mean, across all concurrent requests) +Transfer rate: 210.84 [Kbytes/sec] received + +Connection Times (ms) + min mean[+/-sd] median max +Connect: 0 0 0.2 0 2 +Processing: 1 2 1.0 1 8 +Waiting: 1 1 1.0 1 7 +Total: 1 2 1.2 1 10 + +Percentage of the requests served within a certain time (ms) + 50% 1 + 66% 2 + 75% 2 + 80% 2 + 90% 2 + 95% 3 + 98% 7 + 99% 10 + 100% 10 (longest request)` + + ab := ApacheBench{} + want := 210.84 + got, err := ab.parseTransferRate(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseTransferRate got: %f, want: %f", got, want) + } + + want = 2.0 + got, err = ab.parseLatency(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseLatency got: %f, want: %f", got, want) + } + + want = 556.44 + got, err = ab.parseRequestsPerSecond(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseRequestsPerSecond got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/fio.go b/test/benchmarks/tools/fio.go new file mode 100644 index 000000000..20000db16 --- /dev/null +++ b/test/benchmarks/tools/fio.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + "testing" +) + +// Fio makes 'fio' commands and parses their output. +type Fio struct { + Test string // test to run: read, write, randread, randwrite. + Size string // total size to be read/written of format N[GMK] (e.g. 5G). + Blocksize string // blocksize to be read/write of format N[GMK] (e.g. 4K). + Iodepth int // iodepth for reads/writes. + Time int // time to run the test in seconds, usually for rand(read/write). +} + +// MakeCmd makes a 'fio' command. +func (f *Fio) MakeCmd(filename string) []string { + cmd := []string{"fio", "--output-format=json", "--ioengine=sync"} + cmd = append(cmd, fmt.Sprintf("--name=%s", f.Test)) + cmd = append(cmd, fmt.Sprintf("--size=%s", f.Size)) + cmd = append(cmd, fmt.Sprintf("--blocksize=%s", f.Blocksize)) + cmd = append(cmd, fmt.Sprintf("--filename=%s", filename)) + cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.Iodepth)) + cmd = append(cmd, fmt.Sprintf("--rw=%s", f.Test)) + if f.Time != 0 { + cmd = append(cmd, "--time_based") + cmd = append(cmd, fmt.Sprintf("--runtime=%d", f.Time)) + } + return cmd +} + +// Report reports metrics based on output from an 'fio' command. +func (f *Fio) Report(b *testing.B, output string) { + b.Helper() + // Parse the output and report the metrics. + isRead := strings.Contains(f.Test, "read") + bw, err := f.parseBandwidth(output, isRead) + if err != nil { + b.Fatalf("failed to parse bandwidth from %s with: %v", output, err) + } + b.ReportMetric(bw, "bandwidth_b/s") // in b/s. + + iops, err := f.parseIOps(output, isRead) + if err != nil { + b.Fatalf("failed to parse iops from %s with: %v", output, err) + } + b.ReportMetric(iops, "iops") +} + +// parseBandwidth reports the bandwidth in b/s. +func (f *Fio) parseBandwidth(data string, isRead bool) (float64, error) { + if isRead { + result, err := f.parseFioJSON(data, "read", "bw") + if err != nil { + return 0, err + } + return 1024 * result, nil + } + result, err := f.parseFioJSON(data, "write", "bw") + if err != nil { + return 0, err + } + return 1024 * result, nil +} + +// parseIOps reports the write IO per second metric. +func (f *Fio) parseIOps(data string, isRead bool) (float64, error) { + if isRead { + return f.parseFioJSON(data, "read", "iops") + } + return f.parseFioJSON(data, "write", "iops") +} + +// fioResult is for parsing FioJSON. +type fioResult struct { + Jobs []fioJob +} + +// fioJob is for parsing FioJSON. +type fioJob map[string]json.RawMessage + +// fioMetrics is for parsing FioJSON. +type fioMetrics map[string]json.RawMessage + +// parseFioJSON parses data and grabs "op" (read or write) and "metric" +// (bw or iops) from the JSON. +func (f *Fio) parseFioJSON(data, op, metric string) (float64, error) { + var result fioResult + if err := json.Unmarshal([]byte(data), &result); err != nil { + return 0, fmt.Errorf("could not unmarshal data: %v", err) + } + + if len(result.Jobs) < 1 { + return 0, fmt.Errorf("no jobs present to parse") + } + + var metrics fioMetrics + if err := json.Unmarshal(result.Jobs[0][op], &metrics); err != nil { + return 0, fmt.Errorf("could not unmarshal jobs: %v", err) + } + + if _, ok := metrics[metric]; !ok { + return 0, fmt.Errorf("no metric found for op: %s", op) + } + return strconv.ParseFloat(string(metrics[metric]), 64) +} diff --git a/test/benchmarks/tools/fio_test.go b/test/benchmarks/tools/fio_test.go new file mode 100644 index 000000000..a98277150 --- /dev/null +++ b/test/benchmarks/tools/fio_test.go @@ -0,0 +1,122 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import "testing" + +// TestFio checks the Fio parsers on sample output. +func TestFio(t *testing.T) { + sampleData := ` +{ + "fio version" : "fio-3.1", + "timestamp" : 1554837456, + "timestamp_ms" : 1554837456621, + "time" : "Tue Apr 9 19:17:36 2019", + "jobs" : [ + { + "jobname" : "test", + "groupid" : 0, + "error" : 0, + "eta" : 2147483647, + "elapsed" : 1, + "job options" : { + "name" : "test", + "ioengine" : "sync", + "size" : "1073741824", + "filename" : "/disk/file.dat", + "iodepth" : "4", + "bs" : "4096", + "rw" : "write" + }, + "read" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 123456, + "iops" : 1234.5678, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "write" : { + "io_bytes" : 1073741824, + "io_kbytes" : 1048576, + "bw" : 1753471, + "iops" : 438367.892977, + "runtime" : 598, + "total_ios" : 262144, + "bw_min" : 1731120, + "bw_max" : 1731120, + "bw_agg" : 98.725328, + "bw_mean" : 1731120.000000, + "bw_dev" : 0.000000, + "bw_samples" : 1, + "iops_min" : 432780, + "iops_max" : 432780, + "iops_mean" : 432780.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 1 + } + } + ] +} +` + fio := Fio{} + // WriteBandwidth. + got, err := fio.parseBandwidth(sampleData, false) + var want float64 = 1753471.0 * 1024 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // ReadBandwidth. + got, err = fio.parseBandwidth(sampleData, true) + want = 123456 * 1024 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // WriteIOps. + got, err = fio.parseIOps(sampleData, false) + want = 438367.892977 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // ReadIOps. + got, err = fio.parseIOps(sampleData, true) + want = 1234.5678 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go new file mode 100644 index 000000000..b1e20e356 --- /dev/null +++ b/test/benchmarks/tools/hey.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "testing" +) + +// Hey is for the client application 'hey'. +type Hey struct { + Requests int // Note: requests cannot be less than concurrency. + Concurrency int + Doc string +} + +// MakeCmd returns a 'hey' command. +func (h *Hey) MakeCmd(ip net.IP, port int) []string { + return strings.Split(fmt.Sprintf("hey -n %d -c %d http://%s:%d/%s", + h.Requests, h.Concurrency, ip, port, h.Doc), " ") +} + +// Report parses output from 'hey' and reports metrics. +func (h *Hey) Report(b *testing.B, output string) { + b.Helper() + requests, err := h.parseRequestsPerSecond(output) + if err != nil { + b.Fatalf("failed to parse requests per second: %v", err) + } + b.ReportMetric(requests, "requests_per_second") + + ave, err := h.parseAverageLatency(output) + if err != nil { + b.Fatalf("failed to parse average latency: %v", err) + } + b.ReportMetric(ave, "average_latency_secs") +} + +var heyReqPerSecondRE = regexp.MustCompile(`Requests/sec:\s*(\d+\.?\d+?)\s+`) + +// parseRequestsPerSecond finds requests per second from 'hey' output. +func (h *Hey) parseRequestsPerSecond(data string) (float64, error) { + match := heyReqPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var heyAverageLatencyRE = regexp.MustCompile(`Average:\s*(\d+\.?\d+?)\s+secs`) + +// parseHeyAverageLatency finds Average Latency in seconds form 'hey' output. +func (h *Hey) parseAverageLatency(data string) (float64, error) { + match := heyAverageLatencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get average latency match%d : %s", len(match), data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/hey_test.go b/test/benchmarks/tools/hey_test.go new file mode 100644 index 000000000..e0cab1f52 --- /dev/null +++ b/test/benchmarks/tools/hey_test.go @@ -0,0 +1,81 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import "testing" + +// TestHey checks the Hey parsers on sample output. +func TestHey(t *testing.T) { + sampleData := ` + Summary: + Total: 2.2391 secs + Slowest: 1.6292 secs + Fastest: 0.0066 secs + Average: 0.5351 secs + Requests/sec: 89.3202 + + Total data: 841200 bytes + Size/request: 4206 bytes + + Response time histogram: + 0.007 [1] | + 0.169 [0] | + 0.331 [149] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ + 0.493 [0] | + 0.656 [0] | + 0.818 [0] | + 0.980 [0] | + 1.142 [0] | + 1.305 [0] | + 1.467 [49] |■■■■■■■■■■■■■ + 1.629 [1] | + + + Latency distribution: + 10% in 0.2149 secs + 25% in 0.2449 secs + 50% in 0.2703 secs + 75% in 1.3315 secs + 90% in 1.4045 secs + 95% in 1.4232 secs + 99% in 1.4362 secs + + Details (average, fastest, slowest): + DNS+dialup: 0.0002 secs, 0.0066 secs, 1.6292 secs + DNS-lookup: 0.0000 secs, 0.0000 secs, 0.0000 secs + req write: 0.0000 secs, 0.0000 secs, 0.0012 secs + resp wait: 0.5225 secs, 0.0064 secs, 1.4346 secs + resp read: 0.0122 secs, 0.0001 secs, 0.2006 secs + + Status code distribution: + [200] 200 responses + ` + hey := Hey{} + want := 89.3202 + got, err := hey.parseRequestsPerSecond(sampleData) + if err != nil { + t.Fatalf("failed to parse request per second with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + want = 0.5351 + got, err = hey.parseAverageLatency(sampleData) + if err != nil { + t.Fatalf("failed to parse average latency with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go new file mode 100644 index 000000000..df3d9349b --- /dev/null +++ b/test/benchmarks/tools/iperf.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "testing" +) + +// Iperf is for the client side of `iperf`. +type Iperf struct { + Time int +} + +// MakeCmd returns a iperf client command. +func (i *Iperf) MakeCmd(ip net.IP, port int) []string { + // iperf report in Kb realtime + return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", i.Time, ip, port), " ") +} + +// Report parses output from iperf client and reports metrics. +func (i *Iperf) Report(b *testing.B, output string) { + b.Helper() + // Parse bandwidth and report it. + bW, err := i.bandwidth(output) + if err != nil { + b.Fatalf("failed to parse bandwitdth from %s: %v", output, err) + } + b.ReportMetric(bW*1024, "bandwidth_b/s") // Convert from Kb/s to b/s. +} + +// bandwidth parses the Bandwidth number from an iperf report. A sample is below. +func (i *Iperf) bandwidth(data string) (float64, error) { + re := regexp.MustCompile(`\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec`) + match := re.FindStringSubmatch(data) + if len(match) < 1 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/iperf_test.go b/test/benchmarks/tools/iperf_test.go new file mode 100644 index 000000000..03bb30d05 --- /dev/null +++ b/test/benchmarks/tools/iperf_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package tools + +import "testing" + +// TestIperf checks the Iperf parsers on sample output. +func TestIperf(t *testing.T) { + sampleData := ` +------------------------------------------------------------ +Client connecting to 10.138.15.215, TCP port 32779 +TCP window size: 45.0 KByte (default) +------------------------------------------------------------ +[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 +[ ID] Interval Transfer Bandwidth +[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec +` + i := Iperf{} + bandwidth, err := i.bandwidth(sampleData) + if err != nil || bandwidth != 45900 { + t.Fatalf("failed with: %v and %f", err, bandwidth) + } +} diff --git a/test/benchmarks/tools/meminfo.go b/test/benchmarks/tools/meminfo.go new file mode 100644 index 000000000..2414a96a7 --- /dev/null +++ b/test/benchmarks/tools/meminfo.go @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "testing" +) + +// Meminfo wraps measurements of MemAvailable using /proc/meminfo. +type Meminfo struct { +} + +// MakeCmd returns a command for checking meminfo. +func (*Meminfo) MakeCmd() (string, []string) { + return "cat", []string{"/proc/meminfo"} +} + +// Report takes two reads of meminfo, parses them, and reports the difference +// divided by b.N. +func (*Meminfo) Report(b *testing.B, before, after string) { + b.Helper() + + beforeVal, err := parseMemAvailable(before) + if err != nil { + b.Fatalf("could not parse before value %s: %v", before, err) + } + + afterVal, err := parseMemAvailable(after) + if err != nil { + b.Fatalf("could not parse before value %s: %v", before, err) + } + val := 1024 * ((beforeVal - afterVal) / float64(b.N)) + b.ReportMetric(val, "average_container_size_bytes") +} + +var memInfoRE = regexp.MustCompile(`MemAvailable:\s*(\d+)\skB\n`) + +// parseMemAvailable grabs the MemAvailable number from /proc/meminfo. +func parseMemAvailable(data string) (float64, error) { + match := memInfoRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("couldn't find MemAvailable in %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/meminfo_test.go b/test/benchmarks/tools/meminfo_test.go new file mode 100644 index 000000000..ba803540f --- /dev/null +++ b/test/benchmarks/tools/meminfo_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestMeminfo checks the Meminfo parser on sample output. +func TestMeminfo(t *testing.T) { + sampleData := ` +MemTotal: 16337408 kB +MemFree: 3742696 kB +MemAvailable: 9319948 kB +Buffers: 1433884 kB +Cached: 4607036 kB +SwapCached: 45284 kB +Active: 8288376 kB +Inactive: 2685928 kB +Active(anon): 4724912 kB +Inactive(anon): 1047940 kB +Active(file): 3563464 kB +Inactive(file): 1637988 kB +Unevictable: 326940 kB +Mlocked: 48 kB +SwapTotal: 33292284 kB +SwapFree: 32865736 kB +Dirty: 708 kB +Writeback: 0 kB +AnonPages: 4304204 kB +Mapped: 975424 kB +Shmem: 910292 kB +KReclaimable: 744532 kB +Slab: 1058448 kB +SReclaimable: 744532 kB +SUnreclaim: 313916 kB +KernelStack: 25188 kB +PageTables: 65300 kB +NFS_Unstable: 0 kB +Bounce: 0 kB +WritebackTmp: 0 kB +CommitLimit: 41460988 kB +Committed_AS: 22859492 kB +VmallocTotal: 34359738367 kB +VmallocUsed: 63088 kB +VmallocChunk: 0 kB +Percpu: 9248 kB +HardwareCorrupted: 0 kB +AnonHugePages: 786432 kB +ShmemHugePages: 0 kB +ShmemPmdMapped: 0 kB +FileHugePages: 0 kB +FilePmdMapped: 0 kB +HugePages_Total: 0 +HugePages_Free: 0 +HugePages_Rsvd: 0 +HugePages_Surp: 0 +Hugepagesize: 2048 kB +Hugetlb: 0 kB +DirectMap4k: 5408532 kB +DirectMap2M: 11241472 kB +DirectMap1G: 1048576 kB +` + want := 9319948.0 + got, err := parseMemAvailable(sampleData) + if err != nil { + t.Fatalf("parseMemAvailable failed: %v", err) + } + if got != want { + t.Fatalf("parseMemAvailable got %f, want %f", got, want) + } +} diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go new file mode 100644 index 000000000..c899ae0d4 --- /dev/null +++ b/test/benchmarks/tools/redis.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "testing" +) + +// Redis is for the client 'redis-benchmark'. +type Redis struct { + Operation string +} + +// MakeCmd returns a redis-benchmark client command. +func (r *Redis) MakeCmd(ip net.IP, port int) []string { + // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case. + // Note that "ping" will run both PING_INLINE and PING_BULK. + if r.Operation == "PING_BULK" { + return strings.Split( + fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, port), " ") + } + + // runs redis-benchmark -t operation for 100K requests against server. + return strings.Split( + fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", r.Operation, ip, port), " ") +} + +// Report parses output from redis-benchmark client and reports metrics. +func (r *Redis) Report(b *testing.B, output string) { + b.Helper() + result, err := r.parseOperation(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, r.Operation) // operations per second +} + +// parseOperation grabs the metric operations per second from redis-benchmark output. +func (r *Redis) parseOperation(data string) (float64, error) { + re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, r.Operation)) + match := re.FindStringSubmatch(data) + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find %s in %s", r.Operation, data) + } + return strconv.ParseFloat(match[2], 64) +} diff --git a/test/benchmarks/tools/redis_test.go b/test/benchmarks/tools/redis_test.go new file mode 100644 index 000000000..4bafda66f --- /dev/null +++ b/test/benchmarks/tools/redis_test.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestRedis checks the Redis parsers on sample output. +func TestRedis(t *testing.T) { + sampleData := ` + "PING_INLINE","48661.80" + "PING_BULK","50301.81" + "SET","48923.68" + "GET","49382.71" + "INCR","49975.02" + "LPUSH","49875.31" + "RPUSH","50276.52" + "LPOP","50327.12" + "RPOP","50556.12" + "SADD","49504.95" + "HSET","49504.95" + "SPOP","50025.02" + "LPUSH (needed to benchmark LRANGE)","48875.86" + "LRANGE_100 (first 100 elements)","33955.86" + "LRANGE_300 (first 300 elements)","16550.81"// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + + "LRANGE_500 (first 450 elements)","13653.74" + "LRANGE_600 (first 600 elements)","11219.57" + "MSET (10 keys)","44682.75" + ` + wants := map[string]float64{ + "PING_INLINE": 48661.80, + "PING_BULK": 50301.81, + "SET": 48923.68, + "GET": 49382.71, + "INCR": 49975.02, + "LPUSH": 49875.31, + "RPUSH": 50276.52, + "LPOP": 50327.12, + "RPOP": 50556.12, + "SADD": 49504.95, + "HSET": 49504.95, + "SPOP": 50025.02, + "LRANGE_100": 33955.86, + "LRANGE_300": 16550.81, + "LRANGE_500": 13653.74, + "LRANGE_600": 11219.57, + "MSET": 44682.75, + } + for op, want := range wants { + redis := Redis{ + Operation: op, + } + if got, err := redis.parseOperation(sampleData); err != nil { + t.Fatalf("failed to parse %s: %v", op, err) + } else if want != got { + t.Fatalf("wanted %f for op %s, got %f", want, op, got) + } + } +} diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go new file mode 100644 index 000000000..6b2f75ca2 --- /dev/null +++ b/test/benchmarks/tools/sysbench.go @@ -0,0 +1,245 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "testing" +) + +var warmup = "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null &&" + +// Sysbench represents a 'sysbench' command. +type Sysbench interface { + MakeCmd() []string // Makes a sysbench command. + flags() []string + Report(*testing.B, string) // Reports results contained in string. +} + +// SysbenchBase is the top level struct for sysbench and holds top-level arguments +// for sysbench. See: 'sysbench --help' +type SysbenchBase struct { + Threads int // number of Threads for the test. + Time int // time limit for test in seconds. +} + +// baseFlags returns top level flags. +func (s *SysbenchBase) baseFlags() []string { + var ret []string + if s.Threads > 0 { + ret = append(ret, fmt.Sprintf("--threads=%d", s.Threads)) + } + if s.Time > 0 { + ret = append(ret, fmt.Sprintf("--time=%d", s.Time)) + } + return ret +} + +// SysbenchCPU is for 'sysbench [flags] cpu run' and holds CPU specific arguments. +type SysbenchCPU struct { + Base SysbenchBase + MaxPrime int // upper limit for primes generator [10000]. +} + +// MakeCmd makes commands for SysbenchCPU. +func (s *SysbenchCPU) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "cpu run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchCPU cmds. +func (s *SysbenchCPU) flags() []string { + cmd := s.Base.baseFlags() + if s.MaxPrime > 0 { + return append(cmd, fmt.Sprintf("--cpu-max-prime=%d", s.MaxPrime)) + } + return cmd +} + +// Report reports the relevant metrics for SysbenchCPU. +func (s *SysbenchCPU) Report(b *testing.B, output string) { + b.Helper() + result, err := s.parseEvents(output) + if err != nil { + b.Fatalf("parsing CPU events from %s failed: %v", output, err) + } + b.ReportMetric(result, "cpu_events_per_second") +} + +var cpuEventsPerSecondRE = regexp.MustCompile(`events per second:\s*(\d*.?\d*)\n`) + +// parseEvents parses cpu events per second. +func (s *SysbenchCPU) parseEvents(data string) (float64, error) { + match := cpuEventsPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find events per second: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// SysbenchMemory is for 'sysbench [FLAGS] memory run' and holds Memory specific arguments. +type SysbenchMemory struct { + Base SysbenchBase + BlockSize string // size of test memory block [1K]. + TotalSize string // size of data to transfer [100G]. + Scope string // memory access scope {global, local} [global]. + HugeTLB bool // allocate memory from HugeTLB [off]. + OperationType string // type of memory ops {read, write, none} [write]. + AccessMode string // access mode {seq, rnd} [seq]. +} + +// MakeCmd makes commands for SysbenchMemory. +func (s *SysbenchMemory) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "memory run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchMemory cmds. +func (s *SysbenchMemory) flags() []string { + cmd := s.Base.baseFlags() + if s.BlockSize != "" { + cmd = append(cmd, fmt.Sprintf("--memory-block-size=%s", s.BlockSize)) + } + if s.TotalSize != "" { + cmd = append(cmd, fmt.Sprintf("--memory-total-size=%s", s.TotalSize)) + } + if s.Scope != "" { + cmd = append(cmd, fmt.Sprintf("--memory-scope=%s", s.Scope)) + } + if s.HugeTLB { + cmd = append(cmd, "--memory-hugetlb=on") + } + if s.OperationType != "" { + cmd = append(cmd, fmt.Sprintf("--memory-oper=%s", s.OperationType)) + } + if s.AccessMode != "" { + cmd = append(cmd, fmt.Sprintf("--memory-access-mode=%s", s.AccessMode)) + } + return cmd +} + +// Report reports the relevant metrics for SysbenchMemory. +func (s *SysbenchMemory) Report(b *testing.B, output string) { + b.Helper() + result, err := s.parseOperations(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "operations_per_second") +} + +var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`) + +// parseOperations parses memory operations per second form sysbench memory ouput. +func (s *SysbenchMemory) parseOperations(data string) (float64, error) { + match := memoryOperationsRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("couldn't find memory operations per second: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// SysbenchMutex is for 'sysbench [FLAGS] mutex run' and holds Mutex specific arguments. +type SysbenchMutex struct { + Base SysbenchBase + Num int // total size of mutex array [4096]. + Locks int // number of mutex locks per thread [50K]. + Loops int // number of loops to do outside mutex lock [10K]. +} + +// MakeCmd makes commands for SysbenchMutex. +func (s *SysbenchMutex) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "mutex run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchMutex commands. +func (s *SysbenchMutex) flags() []string { + var cmd []string + cmd = append(cmd, s.Base.baseFlags()...) + if s.Num > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-num=%d", s.Num)) + } + if s.Locks > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-locks=%d", s.Locks)) + } + if s.Loops > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-loops=%d", s.Loops)) + } + return cmd +} + +// Report parses and reports relevant sysbench mutex metrics. +func (s *SysbenchMutex) Report(b *testing.B, output string) { + b.Helper() + + result, err := s.parseExecutionTime(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "average_execution_time_secs") + + result, err = s.parseDeviation(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "stdev_execution_time_secs") + + result, err = s.parseLatency(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result/1000, "average_latency_secs") +} + +var executionTimeRE = regexp.MustCompile(`execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)`) + +// parseExecutionTime parses threads fairness average execution time from sysbench output. +func (s *SysbenchMutex) parseExecutionTime(data string) (float64, error) { + match := executionTimeRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find execution time average: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// parseDeviation parses threads fairness stddev time from sysbench output. +func (s *SysbenchMutex) parseDeviation(data string) (float64, error) { + match := executionTimeRE.FindStringSubmatch(data) + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find execution time deviation: %s", data) + } + return strconv.ParseFloat(match[2], 64) +} + +var averageLatencyRE = regexp.MustCompile(`avg:[^\n^\d]*(\d*\.?\d*)`) + +// parseLatency parses latency from sysbench output. +func (s *SysbenchMutex) parseLatency(data string) (float64, error) { + match := averageLatencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find average latency: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/sysbench_test.go b/test/benchmarks/tools/sysbench_test.go new file mode 100644 index 000000000..850d1939e --- /dev/null +++ b/test/benchmarks/tools/sysbench_test.go @@ -0,0 +1,169 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestSysbenchCpu tests parses on sample 'sysbench cpu' output. +func TestSysbenchCpu(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Prime numbers limit: 10000 + +Initializing worker threads... + +Threads started! + +CPU speed: + events per second: 9093.38 + +General statistics: + total time: 10.0007s + total number of events: 90949 + +Latency (ms): + min: 0.64 + avg: 0.88 + max: 24.65 + 95th percentile: 1.55 + sum: 79936.91 + +Threads fairness: + events (avg/stddev): 11368.6250/831.38 + execution time (avg/stddev): 9.9921/0.01 +` + sysbench := SysbenchCPU{} + want := 9093.38 + if got, err := sysbench.parseEvents(sampleData); err != nil { + t.Fatalf("parse cpu events failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} + +// TestSysbenchMemory tests parsers on sample 'sysbench memory' output. +func TestSysbenchMemory(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Running memory speed test with the following options: + block size: 1KiB + total size: 102400MiB + operation: write + scope: global + +Initializing worker threads... + +Threads started! + +Total operations: 47999046 (9597428.64 per second) + +46874.07 MiB transferred (9372.49 MiB/sec) + + +General statistics: + total time: 5.0001s + total number of events: 47999046 + +Latency (ms): + min: 0.00 + avg: 0.00 + max: 0.21 + 95th percentile: 0.00 + sum: 33165.91 + +Threads fairness: + events (avg/stddev): 5999880.7500/111242.52 + execution time (avg/stddev): 4.1457/0.09 +` + sysbench := SysbenchMemory{} + want := 9597428.64 + if got, err := sysbench.parseOperations(sampleData); err != nil { + t.Fatalf("parse memory ops failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} + +// TestSysbenchMutex tests parsers on sample 'sysbench mutex' output. +func TestSysbenchMutex(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +The 'mutex' test requires a command argument. See 'sysbench mutex help' +root@ec078132e294:/# sysbench mutex --threads=8 run +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Initializing worker threads... + +Threads started! + + +General statistics: + total time: 0.2320s + total number of events: 8 + +Latency (ms): + min: 152.35 + avg: 192.48 + max: 231.41 + 95th percentile: 231.53 + sum: 1539.83 + +Threads fairness: + events (avg/stddev): 1.0000/0.00 + execution time (avg/stddev): 0.1925/0.04 +` + + sysbench := SysbenchMutex{} + want := .1925 + if got, err := sysbench.parseExecutionTime(sampleData); err != nil { + t.Fatalf("parse mutex time failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } + + want = 0.04 + if got, err := sysbench.parseDeviation(sampleData); err != nil { + t.Fatalf("parse mutex deviation failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } + + want = 192.48 + if got, err := sysbench.parseLatency(sampleData); err != nil { + t.Fatalf("parse mutex time failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/tools.go b/test/benchmarks/tools/tools.go new file mode 100644 index 000000000..eb61c0136 --- /dev/null +++ b/test/benchmarks/tools/tools.go @@ -0,0 +1,17 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tools holds tooling to couple command formatting and output parsers +// together. +package tools diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 44cce0e3b..29a84f184 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -23,6 +23,7 @@ go_test( "//pkg/test/dockerutil", "//pkg/test/testutil", "//runsc/specutils", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index 6a63b1232..b47df447c 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -22,12 +22,10 @@ package integration import ( + "context" "fmt" - "os" - "os/exec" "strconv" "strings" - "syscall" "testing" "time" @@ -39,18 +37,19 @@ import ( // Test that exec uses the exact same capability set as the container. func TestExecCapabilities(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { t.Fatalf("docker run failed: %v", err) } // Check that capability. - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) } @@ -61,7 +60,7 @@ func TestExecCapabilities(t *testing.T) { t.Log("Root capabilities:", want) // Now check that exec'd process capabilities match the root. - got, err := d.Exec(dockerutil.RunOpts{}, "grep", "CapEff:", "/proc/self/status") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -74,11 +73,12 @@ func TestExecCapabilities(t *testing.T) { // Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW // which is removed from the container when --net-raw=false. func TestExecPrivileged(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with all capabilities dropped. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", CapDrop: []string{"all"}, }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { @@ -86,7 +86,7 @@ func TestExecPrivileged(t *testing.T) { } // Check that all capabilities where dropped from container. - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) } @@ -104,7 +104,7 @@ func TestExecPrivileged(t *testing.T) { // Check that 'exec --privileged' adds all capabilities, except for // CAP_NET_RAW. - got, err := d.Exec(dockerutil.RunOpts{ + got, err := d.Exec(ctx, dockerutil.ExecOpts{ Privileged: true, }, "grep", "CapEff:", "/proc/self/status") if err != nil { @@ -118,76 +118,59 @@ func TestExecPrivileged(t *testing.T) { } func TestExecJobControl(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - // Exec 'sh' with an attached pty. - if _, err := d.Exec(dockerutil.RunOpts{ - Pty: func(cmd *exec.Cmd, ptmx *os.File) { - // Call "sleep 100 | cat" in the shell. We pipe to cat - // so that there will be two processes in the - // foreground process group. - if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep and - // cat, but not the shell. \x03 is ASCII "end of - // text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // The shell should still be alive at this point. Sleep - // should have exited with code 2+128=130. We'll exit - // with 10 plus that number, so that we can be sure - // that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Exec process should exit with code 10+130=140. - ps, err := cmd.Process.Wait() - if err != nil { - t.Fatalf("error waiting for exec process: %v", err) - } - ws := ps.Sys().(syscall.WaitStatus) - if !ws.Exited() { - t.Errorf("ws.Exited got false, want true") - } - if got, want := ws.ExitStatus(), 140; got != want { - t.Errorf("ws.ExitedStatus got %d, want %d", got, want) - } - }, - }, "sh"); err != nil { + p, err := d.ExecProcess(ctx, dockerutil.ExecOpts{UseTTY: true}, "/bin/sh") + if err != nil { t.Fatalf("docker exec failed: %v", err) } + + if _, err = p.Write(time.Second, []byte("sleep 100 | cat\n")); err != nil { + t.Fatalf("error exit: %v", err) + } + time.Sleep(time.Second) + + if _, err = p.Write(time.Second, []byte{0x03}); err != nil { + t.Fatalf("error exit: %v", err) + } + + if _, err = p.Write(time.Second, []byte("exit $(expr $? + 10)\n")); err != nil { + t.Fatalf("error exit: %v", err) + } + + want := 140 + got, err := p.WaitExitStatus(ctx) + if err != nil { + t.Fatalf("wait for exit failed with: %v", err) + } else if got != want { + t.Fatalf("wait for exit returned: %d want: %d", got, want) + } } // Test that failure to exec returns proper error message. func TestExecError(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } // Attempt to exec a binary that doesn't exist. - out, err := d.Exec(dockerutil.RunOpts{}, "no_can_find") + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "no_can_find") if err == nil { t.Fatalf("docker exec didn't fail") } @@ -198,11 +181,12 @@ func TestExecError(t *testing.T) { // Test that exec inherits environment from run. func TestExecEnv(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with env FOO=BAR. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", Env: []string{"FOO=BAR"}, }, "sleep", "1000"); err != nil { @@ -210,7 +194,7 @@ func TestExecEnv(t *testing.T) { } // Exec "echo $FOO". - got, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "echo $FOO") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $FOO") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -222,11 +206,12 @@ func TestExecEnv(t *testing.T) { // TestRunEnvHasHome tests that run always has HOME environment set. func TestRunEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Exec "echo $HOME". The 'bin' user's home dir is '/bin'. - got, err := d.Run(dockerutil.RunOpts{ + got, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", User: "bin", }, "/bin/sh", "-c", "echo $HOME") @@ -243,17 +228,18 @@ func TestRunEnvHasHome(t *testing.T) { // Test that exec always has HOME environment set, even when not set in run. func TestExecEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } // Exec "echo $HOME", and expect to see "/root". - got, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "echo $HOME") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -265,12 +251,12 @@ func TestExecEnvHasHome(t *testing.T) { newUID := 1234 newHome := "/foo/bar" cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome) - if _, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", cmd); err != nil { + if _, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", cmd); err != nil { t.Fatalf("docker exec failed: %v", err) } // Execute the same as the new user and expect newHome. - got, err = d.Exec(dockerutil.RunOpts{ + got, err = d.Exec(ctx, dockerutil.ExecOpts{ User: strconv.Itoa(newUID), }, "/bin/sh", "-c", "echo $HOME") if err != nil { diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 60e739c6a..8425abecb 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -22,24 +22,27 @@ package integration import ( + "context" "flag" "fmt" "io/ioutil" "net" "net/http" "os" - "os/exec" "path/filepath" "strconv" "strings" - "syscall" "testing" "time" + "github.com/docker/docker/api/types/mount" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/pkg/test/testutil" ) +// defaultWait is the default wait time used for tests. +const defaultWait = time.Minute + // httpRequestSucceeds sends a request to a given url and checks that the status is OK. func httpRequestSucceeds(client http.Client, server string, port int) error { url := fmt.Sprintf("http://%s:%d", server, port) @@ -56,37 +59,38 @@ func httpRequestSucceeds(client http.Client, server string, port int) error { // TestLifeCycle tests a basic Create/Start/Stop docker container life cycle. func TestLifeCycle(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Create(dockerutil.RunOpts{ + port := 80 + if err := d.Create(ctx, dockerutil.RunOpts{ Image: "basic/nginx", - Ports: []int{80}, + Ports: []int{port}, }); err != nil { t.Fatalf("docker create failed: %v", err) } - if err := d.Start(); err != nil { + if err := d.Start(ctx); err != nil { t.Fatalf("docker start failed: %v", err) } - // Test that container is working. - port, err := d.FindPort(80) + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } - client := http.Client{Timeout: time.Duration(2 * time.Second)} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + client := http.Client{Timeout: defaultWait} + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Errorf("http request failed: %v", err) } - if err := d.Stop(); err != nil { + if err := d.Stop(ctx); err != nil { t.Fatalf("docker stop failed: %v", err) } - if err := d.Remove(); err != nil { + if err := d.Remove(ctx); err != nil { t.Fatalf("docker rm failed: %v", err) } } @@ -96,40 +100,43 @@ func TestPauseResume(t *testing.T) { t.Skip("Checkpoint is not supported.") } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + port := 8080 + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", - Ports: []int{8080}, // See Dockerfile. + Ports: []int{port}, // See Dockerfile. }); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check that container is working. - client := http.Client{Timeout: time.Duration(2 * time.Second)} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + client := http.Client{Timeout: defaultWait} + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } - if err := d.Pause(); err != nil { + if err := d.Pause(ctx); err != nil { t.Fatalf("docker pause failed: %v", err) } // Check if container is paused. - switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) { + client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute. + switch _, err := client.Get(fmt.Sprintf("http://%s:%d", ip.String(), port)); v := err.(type) { case nil: t.Errorf("http req expected to fail but it succeeded") case net.Error: @@ -140,17 +147,18 @@ func TestPauseResume(t *testing.T) { t.Errorf("http req got unexpected error %v", v) } - if err := d.Unpause(); err != nil { + if err := d.Unpause(ctx); err != nil { t.Fatalf("docker unpause failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + client = http.Client{Timeout: defaultWait} + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } } @@ -160,70 +168,80 @@ func TestCheckpointRestore(t *testing.T) { t.Skip("Pause/resume is not supported.") } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + // TODO(gvisor.dev/issue/3373): Remove after implementing. + if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 { + t.Skip("CheckpointRestore not implemented in VFS2.") + } else if err != nil { + t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) + } + + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + port := 8080 + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", - Ports: []int{8080}, // See Dockerfile. + Ports: []int{port}, // See Dockerfile. }); err != nil { t.Fatalf("docker run failed: %v", err) } // Create a snapshot. - if err := d.Checkpoint("test"); err != nil { + if err := d.Checkpoint(ctx, "test"); err != nil { t.Fatalf("docker checkpoint failed: %v", err) } - if _, err := d.Wait(30 * time.Second); err != nil { + if err := d.WaitTimeout(ctx, defaultWait); err != nil { t.Fatalf("wait failed: %v", err) } // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed. - if err := testutil.Poll(func() error { return d.Restore("test") }, 15*time.Second); err != nil { + if err := testutil.Poll(func() error { return d.Restore(ctx, "test") }, defaultWait); err != nil { t.Fatalf("docker restore failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. - client := http.Client{Timeout: time.Duration(2 * time.Second)} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + client := http.Client{Timeout: defaultWait} + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } } // Create client and server that talk to each other using the local IP. func TestConnectToSelf(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Creates server that replies "server" and exists. Sleeps at the end because // 'docker exec' gets killed if the init process exists before it can finish. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/ubuntu", }, "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { t.Fatalf("docker run failed: %v", err) } // Finds IP address for host. - ip, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") + ip, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") if err != nil { t.Fatalf("docker exec failed: %v", err) } ip = strings.TrimRight(ip, "\n") // Runs client that sends "client" to the server and exits. - reply, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) + reply, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -232,21 +250,22 @@ func TestConnectToSelf(t *testing.T) { if want := "server\n"; reply != want { t.Errorf("Error on server, want: %q, got: %q", want, reply) } - if _, err := d.WaitForOutput("^client\n$", 1*time.Second); err != nil { + if _, err := d.WaitForOutput(ctx, "^client\n$", defaultWait); err != nil { t.Fatalf("docker.WaitForOutput(client) timeout: %v", err) } } func TestMemLimit(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // N.B. Because the size of the memory file may grow in large chunks, // there is a minimum threshold of 1GB for the MemTotal figure. - allocMemory := 1024 * 1024 - out, err := d.Run(dockerutil.RunOpts{ + allocMemory := 1024 * 1024 // In kb. + out, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", - Memory: allocMemory, // In kB. + Memory: allocMemory * 1024, // In bytes. }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'") if err != nil { t.Fatalf("docker run failed: %v", err) @@ -272,13 +291,14 @@ func TestMemLimit(t *testing.T) { } func TestNumCPU(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Read how many cores are in the container. - out, err := d.Run(dockerutil.RunOpts{ - Image: "basic/alpine", - Extra: []string{"--cpuset-cpus=0"}, + out, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + CpusetCpus: "0", }, "sh", "-c", "cat /proc/cpuinfo | grep 'processor.*:' | wc -l") if err != nil { t.Fatalf("docker run failed: %v", err) @@ -296,48 +316,34 @@ func TestNumCPU(t *testing.T) { // TestJobControl tests that job control characters are handled properly. func TestJobControl(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with an attached PTY. - if _, err := d.Run(dockerutil.RunOpts{ + p, err := d.SpawnProcess(ctx, dockerutil.RunOpts{ Image: "basic/alpine", - Pty: func(_ *exec.Cmd, ptmx *os.File) { - // Call "sleep 100" in the shell. - if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) + }, "sh", "-c", "sleep 100 | cat") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + // Give shell a few seconds to start executing the sleep. + time.Sleep(2 * time.Second) - // Send a ^C to the pty, which should kill sleep, but - // not the shell. \x03 is ASCII "end of text", which - // is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } + if _, err := p.Write(time.Second, []byte{0x03}); err != nil { + t.Fatalf("error exit: %v", err) + } - // The shell should still be alive at this point. Sleep - // should have exited with code 2+128=130. We'll exit - // with 10 plus that number, so that we can be sure - // that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - }, - }, "sh"); err != nil { - t.Fatalf("docker run failed: %v", err) + if err := d.WaitTimeout(ctx, 3*time.Second); err != nil { + t.Fatalf("WaitTimeout failed: %v", err) } - // Wait for the container to exit. - got, err := d.Wait(5 * time.Second) + want := 130 + got, err := p.WaitExitStatus(ctx) if err != nil { - t.Fatalf("error getting exit code: %v", err) - } - // Container should exit with code 10+130=140. - if want := syscall.WaitStatus(140); got != want { - t.Errorf("container exited with code %d want %d", got, want) + t.Fatalf("wait for exit failed with: %v", err) + } else if got != want { + t.Fatalf("got: %d want: %d", got, want) } } @@ -356,15 +362,16 @@ func TestWorkingDirCreation(t *testing.T) { name += "-readonly" } t.Run(name, func(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) opts := dockerutil.RunOpts{ Image: "basic/alpine", WorkDir: tc.workingDir, ReadOnly: readonly, } - got, err := d.Run(opts, "sh", "-c", "echo ${PWD}") + got, err := d.Run(ctx, opts, "sh", "-c", "echo ${PWD}") if err != nil { t.Fatalf("docker run failed: %v", err) } @@ -378,11 +385,12 @@ func TestWorkingDirCreation(t *testing.T) { // TestTmpFile checks that files inside '/tmp' are not overridden. func TestTmpFile(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - opts := dockerutil.RunOpts{Image: "tmpfile"} - got, err := d.Run(opts, "cat", "/tmp/foo/file.txt") + opts := dockerutil.RunOpts{Image: "basic/tmpfile"} + got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt") if err != nil { t.Fatalf("docker run failed: %v", err) } @@ -393,6 +401,7 @@ func TestTmpFile(t *testing.T) { // TestTmpMount checks that mounts inside '/tmp' are not overridden. func TestTmpMount(t *testing.T) { + ctx := context.Background() dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount") if err != nil { t.Fatalf("TempDir(): %v", err) @@ -401,19 +410,20 @@ func TestTmpMount(t *testing.T) { if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil { t.Fatalf("WriteFile(): %v", err) } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) opts := dockerutil.RunOpts{ Image: "basic/alpine", - Mounts: []dockerutil.Mount{ + Mounts: []mount.Mount{ { + Type: mount.TypeBind, Source: dir, Target: "/tmp/foo", }, }, } - got, err := d.Run(opts, "cat", "/tmp/foo/file.txt") + got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt") if err != nil { t.Fatalf("docker run failed: %v", err) } @@ -426,14 +436,61 @@ func TestTmpMount(t *testing.T) { // runsc to hide the incoherence of FDs opened before and after overlayfs // copy-up on the host. func TestHostOverlayfsCopyUp(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/hostoverlaytest", + WorkDir: "/root", + }, "./test_copy_up"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) + } +} + +// TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory +// stream to refer to the current state of the corresponding directory, as a +// call to opendir() would have done" as required by POSIX, when the directory +// in question is host overlayfs. +// +// This test specifically targets host overlayfs because, per POSIX, "if a file +// is removed from or added to the directory after the most recent call to +// opendir() or rewinddir(), whether a subsequent call to readdir() returns an +// entry for that file is unspecified"; the host filesystems used by other +// automated tests yield newly-added files from readdir() even if the fsgofer +// does not explicitly rewinddir(), but overlayfs does not. +func TestHostOverlayfsRewindDir(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/hostoverlaytest", + WorkDir: "/root", + }, "./test_rewinddir"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) + } +} + +// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it +// cannot use tricks like userns as root. For this reason, run a basic link test +// to ensure some coverage. +func TestLink(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if _, err := d.Run(dockerutil.RunOpts{ - Image: "hostoverlaytest", + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/linktest", WorkDir: "/root", - }, "./test"); err != nil { + }, "./link_test"); err != nil { t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) } } diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go index 327a2174c..70bbe5121 100644 --- a/test/e2e/regression_test.go +++ b/test/e2e/regression_test.go @@ -15,6 +15,7 @@ package integration import ( + "context" "strings" "testing" @@ -27,11 +28,12 @@ import ( // Prerequisite: the directory where the socket file is created must not have // been open for write before bind(2) is called. func TestBindOverlay(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Run the container. - got, err := d.Run(dockerutil.RunOpts{ + got, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/ubuntu", }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p") if err != nil { diff --git a/test/fuse/BUILD b/test/fuse/BUILD new file mode 100644 index 000000000..8e31fdd41 --- /dev/null +++ b/test/fuse/BUILD @@ -0,0 +1,73 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:stat_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:open_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:release_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:mknod_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:symlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:readlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:mkdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:read_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:write_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:rmdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:readdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:create_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:unlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:setstat_test", +) diff --git a/test/fuse/README.md b/test/fuse/README.md new file mode 100644 index 000000000..65add57e2 --- /dev/null +++ b/test/fuse/README.md @@ -0,0 +1,188 @@ +# gVisor FUSE Test Suite + +This is an integration test suite for fuse(4) filesystem. It runs under gVisor +sandbox container with VFS2 and FUSE function enabled. + +This document describes the framework of FUSE integration test, how to use it, +and the guidelines that should be followed when adding new testing features. + +## Integration Test Framework + +By inheriting the `FuseTest` class defined in `linux/fuse_base.h`, every test +fixture can run in an environment with `mount_point_` mounted by a fake FUSE +server. It creates a `socketpair(2)` to send and receive control commands and +data between the client and the server. Because the FUSE server runs in the +background thread, gTest cannot catch its assertion failure immediately. Thus, +`TearDown()` function sends command to the FUSE server to check if all gTest +assertion in the server are successful and all requests and preset responses are +consumed. + +## Communication Diagram + +Diagram below describes how a testing thread communicates with the FUSE server +to achieve integration test. + +For the following diagram, `>` means entering the function, `<` is leaving the +function, and `=` indicates sequentially entering and leaving. Not necessarily +follow exactly the below diagram due to the nature of a multi-threaded system, +however, it is still helpful to know when the client waits for the server to +complete a command and when the server awaits the next instruction. + +``` +| Client (Testing Thread) | Server (FUSE Server Thread) +| | +| >TEST_F() | +| >SetUp() | +| =MountFuse() | +| >SetUpFuseServer() | +| [create communication socket]| +| =fork() | =fork() +| [wait server complete] | +| | =ServerConsumeFuseInit() +| | =ServerCompleteWith() +| <SetUpFuseServer() | +| <SetUp() | +| [testing main] | +| | >ServerFuseLoop() +| | [poll on socket and fd] +| >SetServerResponse() | +| [write data to socket] | +| [wait server complete] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerReceiveResponse() +| | [read data from socket] +| | [save data to memory] +| | <ServerReceiveResponse() +| | =ServerCompleteWith() +| <SetServerResponse() | +| | <ServerHandleCommand() +| >[Do fs operation] | +| [wait for fs response] | +| | [fd event occurs] +| | >ServerProcessFuseRequest() +| | =[read fs request] +| | =[save fs request to memory] +| | =[write fs response] +| <[Do fs operation] | +| | <ServerProcessFuseRequest() +| | +| =[Test fs operation result] | +| | +| >GetServerActualRequest() | +| [write data to socket] | +| [wait data from server] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerSendReceivedRequest() +| | [write data to socket] +| [read data from socket] | +| [wait server complete] | +| | <ServerSendReceivedRequest() +| | =ServerCompleteWith() +| <GetServerActualRequest() | +| | <ServerHandleCommand() +| | +| =[Test actual request] | +| | +| >TearDown() | +| ... | +| >GetServerNumUnsentResponses() | +| [write data to socket] | +| [wait server complete] | +| | [socket event arrive] +| | >ServerHandleCommand() +| | >ServerSendData() +| | [write data to socket] +| | <ServerSendData() +| | =ServerCompleteWith() +| [read data from socket] | +| [test if all succeeded] | +| <GetServerNumUnsentResponses() | +| | <ServerHandleCommand() +| =UnmountFuse() | +| <TearDown() | +| <TEST_F() | +``` + +## Running the tests + +Based on syscall tests, FUSE tests generate targets only with vfs2 and fuse +enabled. The corresponding targets end in `_fuse`. + +For example, to run fuse test in `stat_test.cc`: + +```bash +$ bazel test //test/fuse:stat_test_runsc_ptrace_vfs2_fuse +``` + +Test all targets tagged with fuse: + +```bash +$ bazel test --test_tag_filters=fuse //test/fuse/... +``` + +## Writing a new FUSE test + +1. Add test targets in `BUILD` and `linux/BUILD`. +2. Inherit your test from `FuseTest` base class. It allows you to: + - Fork a fake FUSE server in background during each test setup. + - Create a pair of sockets for communication and provide utility + functions. + - Stop FUSE server and check if error occurs in it after test completes. +3. Build the expected opcode-response pairs of your FUSE operation. +4. Call `SetServerResponse()` to preset the next expected opcode and response. +5. Do real filesystem operations (FUSE is mounted at `mount_point_`). +6. Check FUSE response and/or errors. +7. Retrieve FUSE request by `GetServerActualRequest()`. +8. Check if the request is as expected. + +A few customized matchers used in syscalls test are encouraged to test the +outcome of filesystem operations. Such as: + +```cc +SyscallSucceeds() +SyscallSucceedsWithValue(...) +SyscallFails() +SyscallFailsWithErrno(...) +``` + +Please refer to [test/syscalls/README.md](../syscalls/README.md) for further +details. + +## Writing a new FuseTestCmd + +A `FuseTestCmd` is a control protocol used in the communication between the +testing thread and the FUSE server. Such commands are sent from the testing +thread to the FUSE server to set up, control, or inspect the behavior of the +FUSE server in response to a sequence of FUSE requests. + +The lifecycle of a command contains following steps: + +1. The testing thread sends a `FuseTestCmd` via socket and waits for + completion. +2. The FUSE server receives the command and does corresponding action. +3. (Optional) The testing thread reads data from socket. +4. The FUSE server sends a success indicator via socket after processing. +5. The testing thread gets the success signal and continues testing. + +The success indicator, i.e. `WaitServerComplete()`, is crucial at the end of +each `FuseTestCmd` sent from the testing thread. Because we don't want to begin +filesystem operation if the requests have not been completely set up. Also, to +test FUSE interactions in a sequential manner, concurrent requests are not +supported now. + +To add a new `FuseTestCmd`, one must comply with following format: + +1. Add a new `FuseTestCmd` enum class item defined in `linux/fuse_base.h` +2. Add a `SetServerXXX()` or `GetServerXXX()` public function in `FuseTest`. + This is how the testing thread will call to send control message. Define how + many bytes you want to send along with the command and what you will expect + to receive. Finally it should block and wait for a success indicator from + the FUSE server. +3. Add a handler logic in the switch condition of `ServerHandleCommand()`. Use + `ServerSendData()` or declare a new private function such as + `ServerReceiveXXX()` or `ServerSendXXX()`. It is mandatory to set it private + since only the FUSE server (forked from `FuseTest` base class) can call it. + This is the server part of the specific `FuseTestCmd` and the format of the + data should be consistent with what the client expects in the previous step. diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD new file mode 100644 index 000000000..7673252ec --- /dev/null +++ b/test/fuse/linux/BUILD @@ -0,0 +1,230 @@ +load("//tools:defs.bzl", "cc_binary", "cc_library", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "stat_test", + testonly = 1, + srcs = ["stat_test.cc"], + deps = [ + gtest, + ":fuse_fd_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "open_test", + testonly = 1, + srcs = ["open_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "release_test", + testonly = 1, + srcs = ["release_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "mknod_test", + testonly = 1, + srcs = ["mknod_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "symlink_test", + testonly = 1, + srcs = ["symlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "readlink_test", + testonly = 1, + srcs = ["readlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "mkdir_test", + testonly = 1, + srcs = ["mkdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "setstat_test", + testonly = 1, + srcs = ["setstat_test.cc"], + deps = [ + gtest, + ":fuse_fd_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "rmdir_test", + testonly = 1, + srcs = ["rmdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "readdir_test", + testonly = 1, + srcs = ["readdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_library( + name = "fuse_base", + testonly = 1, + srcs = ["fuse_base.cc"], + hdrs = ["fuse_base.h"], + deps = [ + gtest, + "//test/util:fuse_util", + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_util", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "fuse_fd_util", + testonly = 1, + srcs = ["fuse_fd_util.cc"], + hdrs = ["fuse_fd_util.h"], + deps = [ + gtest, + ":fuse_base", + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:fuse_util", + "//test/util:posix_error", + ], +) + +cc_binary( + name = "read_test", + testonly = 1, + srcs = ["read_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "write_test", + testonly = 1, + srcs = ["write_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "create_test", + testonly = 1, + srcs = ["create_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "unlink_test", + testonly = 1, + srcs = ["unlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/fuse/linux/create_test.cc b/test/fuse/linux/create_test.cc new file mode 100644 index 000000000..9a0219a58 --- /dev/null +++ b/test/fuse/linux/create_test.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class CreateTest : public FuseTest { + protected: + const std::string test_file_name_ = "test_file"; + const mode_t mode = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(CreateTest, CreateFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_name_); + + // Ensure the file doesn't exist. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_LOOKUP, iov_out); + + // creat(2) is equal to open(2) with open_flags O_CREAT | O_WRONLY | O_TRUNC. + const mode_t new_mask = S_IWGRP | S_IWOTH; + const int open_flags = O_CREAT | O_WRONLY | O_TRUNC; + out_header.error = 0; + out_header.len = sizeof(struct fuse_out_header) + + sizeof(struct fuse_entry_out) + sizeof(struct fuse_open_out); + struct fuse_entry_out entry_payload = DefaultEntryOut(mode & ~new_mask, 2); + struct fuse_open_out out_payload = { + .fh = 1, + .open_flags = open_flags, + }; + iov_out = FuseGenerateIovecs(out_header, entry_payload, out_payload); + SetServerResponse(FUSE_CREATE, iov_out); + + // kernfs generates a successive FUSE_OPEN after the file is created. Linux's + // fuse kernel module will not send this FUSE_OPEN after creat(2). + out_header.len = + sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out); + iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + + int fd; + TempUmask mask(new_mask); + EXPECT_THAT(fd = creat(test_file_path.c_str(), mode), SyscallSucceeds()); + EXPECT_THAT(fcntl(fd, F_GETFL), + SyscallSucceedsWithValue(open_flags & O_ACCMODE)); + + struct fuse_in_header in_header; + struct fuse_create_in in_payload; + std::vector<char> name(test_file_name_.size() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, name); + + // Skip the request of FUSE_LOOKUP. + SkipServerActualRequest(); + + // Get the first FUSE_CREATE. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload) + + test_file_name_.size() + 1); + EXPECT_EQ(in_header.opcode, FUSE_CREATE); + EXPECT_EQ(in_payload.flags, open_flags); + EXPECT_EQ(in_payload.mode, mode & ~new_mask); + EXPECT_EQ(in_payload.umask, new_mask); + EXPECT_EQ(std::string(name.data()), test_file_name_); + + // Get the successive FUSE_OPEN. + struct fuse_open_in in_payload_open; + iov_in = FuseGenerateIovecs(in_header, in_payload_open); + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload_open)); + EXPECT_EQ(in_header.opcode, FUSE_OPEN); + EXPECT_EQ(in_payload_open.flags, open_flags & O_ACCMODE); + + EXPECT_THAT(close(fd), SyscallSucceeds()); + // Skip the FUSE_RELEASE. + SkipServerActualRequest(); +} + +TEST_F(CreateTest, CreateFileAlreadyExists) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_name_); + + const int open_flags = O_CREAT | O_EXCL; + + SetServerInodeLookup(test_file_name_); + + EXPECT_THAT(open(test_file_path.c_str(), mode, open_flags), + SyscallFailsWithErrno(EEXIST)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc new file mode 100644 index 000000000..5b45804e1 --- /dev/null +++ b/test/fuse/linux/fuse_base.cc @@ -0,0 +1,447 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/fuse/linux/fuse_base.h" + +#include <fcntl.h> +#include <linux/fuse.h> +#include <poll.h> +#include <sys/mount.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/strings/str_format.h" +#include "test/util/fuse_util.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +void FuseTest::SetUp() { + MountFuse(); + SetUpFuseServer(); +} + +void FuseTest::TearDown() { + EXPECT_EQ(GetServerNumUnconsumedRequests(), 0); + EXPECT_EQ(GetServerNumUnsentResponses(), 0); + UnmountFuse(); +} + +// Sends 3 parts of data to the FUSE server: +// 1. The `kSetResponse` command +// 2. The expected opcode +// 3. The fake FUSE response +// Then waits for the FUSE server to notify its completion. +void FuseTest::SetServerResponse(uint32_t opcode, + std::vector<struct iovec>& iovecs) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetResponse); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &opcode, sizeof(opcode)), + SyscallSucceedsWithValue(sizeof(opcode))); + + EXPECT_THAT(RetryEINTR(writev)(sock_[0], iovecs.data(), iovecs.size()), + SyscallSucceeds()); + + WaitServerComplete(); +} + +// Waits for the FUSE server to finish its blocking job and check if it +// completes without errors. +void FuseTest::WaitServerComplete() { + uint32_t success; + EXPECT_THAT(RetryEINTR(read)(sock_[0], &success, sizeof(success)), + SyscallSucceedsWithValue(sizeof(success))); + ASSERT_EQ(success, 1); +} + +// Sends the `kGetRequest` command to the FUSE server, then reads the next +// request into iovec struct. The order of calling this function should be +// the same as the one of SetServerResponse(). +void FuseTest::GetServerActualRequest(std::vector<struct iovec>& iovecs) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kGetRequest); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(readv)(sock_[0], iovecs.data(), iovecs.size()), + SyscallSucceeds()); + + WaitServerComplete(); +} + +// Sends a FuseTestCmd command to the FUSE server, reads from the socket, and +// returns the corresponding data. +uint32_t FuseTest::GetServerData(uint32_t cmd) { + uint32_t data; + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(read)(sock_[0], &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + WaitServerComplete(); + return data; +} + +uint32_t FuseTest::GetServerNumUnconsumedRequests() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetNumUnconsumedRequests)); +} + +uint32_t FuseTest::GetServerNumUnsentResponses() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetNumUnsentResponses)); +} + +uint32_t FuseTest::GetServerTotalReceivedBytes() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetTotalReceivedBytes)); +} + +// Sends the `kSkipRequest` command to the FUSE server, which would skip +// current stored request data. +void FuseTest::SkipServerActualRequest() { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSkipRequest); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + WaitServerComplete(); +} + +// Sends the `kSetInodeLookup` command, expected mode, and the path of the +// inode to create under the mount point. +void FuseTest::SetServerInodeLookup(const std::string& path, mode_t mode, + uint64_t size) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetInodeLookup); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &mode, sizeof(mode)), + SyscallSucceedsWithValue(sizeof(mode))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &size, sizeof(size)), + SyscallSucceedsWithValue(sizeof(size))); + + // Pad 1 byte for null-terminate c-string. + EXPECT_THAT(RetryEINTR(write)(sock_[0], path.c_str(), path.size() + 1), + SyscallSucceedsWithValue(path.size() + 1)); + + WaitServerComplete(); +} + +void FuseTest::MountFuse(const char* mountOpts) { + EXPECT_THAT(dev_fd_ = open("/dev/fuse", O_RDWR), SyscallSucceeds()); + + std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, mountOpts); + mount_point_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(mount("fuse", mount_point_.path().c_str(), "fuse", + MS_NODEV | MS_NOSUID, mount_opts.c_str()), + SyscallSucceeds()); +} + +void FuseTest::UnmountFuse() { + EXPECT_THAT(umount(mount_point_.path().c_str()), SyscallSucceeds()); + // TODO(gvisor.dev/issue/3330): ensure the process is terminated successfully. +} + +// Consumes the first FUSE request and returns the corresponding PosixError. +PosixError FuseTest::ServerConsumeFuseInit( + const struct fuse_init_out* out_payload) { + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + RETURN_ERROR_IF_SYSCALL_FAIL( + RetryEINTR(read)(dev_fd_, buf.data(), buf.size())); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_init_out), + .error = 0, + .unique = 2, + }; + // Returns a fake fuse_init_out with 7.0 version to avoid ECONNREFUSED + // error in the initialization of FUSE connection. + auto iov_out = FuseGenerateIovecs( + out_header, *const_cast<struct fuse_init_out*>(out_payload)); + + RETURN_ERROR_IF_SYSCALL_FAIL( + RetryEINTR(writev)(dev_fd_, iov_out.data(), iov_out.size())); + return NoError(); +} + +// Reads 1 expected opcode and a fake response from socket and save them into +// the serial buffer of this testing instance. +void FuseTest::ServerReceiveResponse() { + ssize_t len; + uint32_t opcode; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + EXPECT_THAT(RetryEINTR(read)(sock_[1], &opcode, sizeof(opcode)), + SyscallSucceedsWithValue(sizeof(opcode))); + + EXPECT_THAT(len = RetryEINTR(read)(sock_[1], buf.data(), buf.size()), + SyscallSucceeds()); + + responses_.AddMemBlock(opcode, buf.data(), len); +} + +// Writes 1 byte of success indicator through socket. +void FuseTest::ServerCompleteWith(bool success) { + uint32_t data = success ? 1 : 0; + ServerSendData(data); +} + +// ServerFuseLoop is the implementation of the fake FUSE server. Monitors 2 +// file descriptors: /dev/fuse and sock_[1]. Events from /dev/fuse are FUSE +// requests and events from sock_[1] are FUSE testing commands, leading by +// a FuseTestCmd data to indicate the command. +void FuseTest::ServerFuseLoop() { + const int nfds = 2; + struct pollfd fds[nfds] = { + { + .fd = dev_fd_, + .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL, + }, + { + .fd = sock_[1], + .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL, + }, + }; + + while (true) { + ASSERT_THAT(poll(fds, nfds, -1), SyscallSucceeds()); + + for (int fd_idx = 0; fd_idx < nfds; ++fd_idx) { + if (fds[fd_idx].revents == 0) continue; + + ASSERT_EQ(fds[fd_idx].revents, POLL_IN); + if (fds[fd_idx].fd == sock_[1]) { + ServerHandleCommand(); + } else if (fds[fd_idx].fd == dev_fd_) { + ServerProcessFuseRequest(); + } + } + } +} + +// SetUpFuseServer creates 1 socketpair and fork the process. The parent thread +// becomes testing thread and the child thread becomes the FUSE server running +// in background. These 2 threads are connected via socketpair. sock_[0] is +// opened in testing thread and sock_[1] is opened in the FUSE server. +void FuseTest::SetUpFuseServer(const struct fuse_init_out* payload) { + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_), SyscallSucceeds()); + + switch (fork()) { + case -1: + GTEST_FAIL(); + return; + case 0: + break; + default: + ASSERT_THAT(close(sock_[1]), SyscallSucceeds()); + WaitServerComplete(); + return; + } + + // Begin child thread, i.e. the FUSE server. + ASSERT_THAT(close(sock_[0]), SyscallSucceeds()); + ServerCompleteWith(ServerConsumeFuseInit(payload).ok()); + ServerFuseLoop(); + _exit(0); +} + +void FuseTest::ServerSendData(uint32_t data) { + EXPECT_THAT(RetryEINTR(write)(sock_[1], &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); +} + +// Reads FuseTestCmd sent from testing thread and routes to correct handler. +// Since each command should be a blocking operation, a `ServerCompleteWith()` +// is required after the switch keyword. +void FuseTest::ServerHandleCommand() { + uint32_t cmd; + EXPECT_THAT(RetryEINTR(read)(sock_[1], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + switch (static_cast<FuseTestCmd>(cmd)) { + case FuseTestCmd::kSetResponse: + ServerReceiveResponse(); + break; + case FuseTestCmd::kSetInodeLookup: + ServerReceiveInodeLookup(); + break; + case FuseTestCmd::kGetRequest: + ServerSendReceivedRequest(); + break; + case FuseTestCmd::kGetTotalReceivedBytes: + ServerSendData(static_cast<uint32_t>(requests_.UsedBytes())); + break; + case FuseTestCmd::kGetNumUnconsumedRequests: + ServerSendData(static_cast<uint32_t>(requests_.RemainingBlocks())); + break; + case FuseTestCmd::kGetNumUnsentResponses: + ServerSendData(static_cast<uint32_t>(responses_.RemainingBlocks())); + break; + case FuseTestCmd::kSkipRequest: + ServerSkipReceivedRequest(); + break; + default: + FAIL() << "Unknown FuseTestCmd " << cmd; + break; + } + + ServerCompleteWith(!HasFailure()); +} + +// Reads the expected file mode and the path of one file. Crafts a basic +// `fuse_entry_out` memory block and inserts into a map for future use. +// The FUSE server will always return this response if a FUSE_LOOKUP +// request with this specific path comes in. +void FuseTest::ServerReceiveInodeLookup() { + mode_t mode; + uint64_t size; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], &mode, sizeof(mode)), + SyscallSucceedsWithValue(sizeof(mode))); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], &size, sizeof(size)), + SyscallSucceedsWithValue(sizeof(size))); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], buf.data(), buf.size()), + SyscallSucceeds()); + + std::string path(buf.data()); + + uint32_t out_len = + sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out); + struct fuse_out_header out_header = { + .len = out_len, + .error = 0, + }; + struct fuse_entry_out out_payload = DefaultEntryOut(mode, nodeid_); + // Since this is only used in test, nodeid_ is simply increased by 1 to + // comply with the unqiueness of different path. + ++nodeid_; + + // Set the size. + out_payload.attr.size = size; + + memcpy(buf.data(), &out_header, sizeof(out_header)); + memcpy(buf.data() + sizeof(out_header), &out_payload, sizeof(out_payload)); + lookups_.AddMemBlock(FUSE_LOOKUP, buf.data(), out_len); + lookup_map_[path] = lookups_.Next(); +} + +// Sends the received request pointed by current cursor and advances cursor. +void FuseTest::ServerSendReceivedRequest() { + if (requests_.End()) { + FAIL() << "No more received request."; + return; + } + auto mem_block = requests_.Next(); + EXPECT_THAT( + RetryEINTR(write)(sock_[1], requests_.DataAtOffset(mem_block.offset), + mem_block.len), + SyscallSucceedsWithValue(mem_block.len)); +} + +// Skip the request pointed by current cursor. +void FuseTest::ServerSkipReceivedRequest() { + if (requests_.End()) { + FAIL() << "No more received request."; + return; + } + requests_.Next(); +} + +// Handles FUSE request. Reads request from /dev/fuse, checks if it has the +// same opcode as expected, and responds with the saved fake FUSE response. +// The FUSE request is copied to the serial buffer and can be retrieved one- +// by-one by calling GetServerActualRequest from testing thread. +void FuseTest::ServerProcessFuseRequest() { + ssize_t len; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + + // Read FUSE request. + EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf.data(), buf.size()), + SyscallSucceeds()); + fuse_in_header* in_header = reinterpret_cast<fuse_in_header*>(buf.data()); + + // Check if this is a preset FUSE_LOOKUP path. + if (in_header->opcode == FUSE_LOOKUP) { + std::string path(buf.data() + sizeof(struct fuse_in_header)); + auto it = lookup_map_.find(path); + if (it != lookup_map_.end()) { + // Matches a preset path. Reply with fake data and skip saving the + // request. + ServerRespondFuseSuccess(lookups_, it->second, in_header->unique); + return; + } + } + + requests_.AddMemBlock(in_header->opcode, buf.data(), len); + + if (in_header->opcode == FUSE_RELEASE || in_header->opcode == FUSE_RELEASEDIR) + return; + // Check if there is a corresponding response. + if (responses_.End()) { + GTEST_NONFATAL_FAILURE_("No more FUSE response is expected"); + ServerRespondFuseError(in_header->unique); + return; + } + auto mem_block = responses_.Next(); + if (in_header->opcode != mem_block.opcode) { + std::string message = absl::StrFormat("Expect opcode %d but got %d", + mem_block.opcode, in_header->opcode); + GTEST_NONFATAL_FAILURE_(message.c_str()); + // We won't get correct response if opcode is not expected. Send error + // response here to avoid wrong parsing by VFS. + ServerRespondFuseError(in_header->unique); + return; + } + + // Write FUSE response. + ServerRespondFuseSuccess(responses_, mem_block, in_header->unique); +} + +void FuseTest::ServerRespondFuseSuccess(FuseMemBuffer& mem_buf, + const FuseMemBlock& block, + uint64_t unique) { + fuse_out_header* out_header = + reinterpret_cast<fuse_out_header*>(mem_buf.DataAtOffset(block.offset)); + + // Patch `unique` in fuse_out_header to avoid EINVAL caused by responding + // with an unknown `unique`. + out_header->unique = unique; + EXPECT_THAT(RetryEINTR(write)(dev_fd_, out_header, block.len), + SyscallSucceedsWithValue(block.len)); +} + +void FuseTest::ServerRespondFuseError(uint64_t unique) { + fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = ENOSYS, + .unique = unique, + }; + EXPECT_THAT(RetryEINTR(write)(dev_fd_, &out_header, sizeof(out_header)), + SyscallSucceedsWithValue(sizeof(out_header))); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_base.h b/test/fuse/linux/fuse_base.h new file mode 100644 index 000000000..6ad296ca2 --- /dev/null +++ b/test/fuse/linux/fuse_base.h @@ -0,0 +1,251 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_FUSE_FUSE_BASE_H_ +#define GVISOR_TEST_FUSE_FUSE_BASE_H_ + +#include <linux/fuse.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/uio.h> + +#include <iostream> +#include <unordered_map> +#include <vector> + +#include "gtest/gtest.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +constexpr char kMountOpts[] = "rootmode=755,user_id=0,group_id=0"; + +constexpr struct fuse_init_out kDefaultFUSEInitOutPayload = {.major = 7}; + +// Internal commands used to communicate between testing thread and the FUSE +// server. See test/fuse/README.md for further detail. +enum class FuseTestCmd { + kSetResponse = 0, + kSetInodeLookup, + kGetRequest, + kGetNumUnconsumedRequests, + kGetNumUnsentResponses, + kGetTotalReceivedBytes, + kSkipRequest, +}; + +// Holds the information of a memory block in a serial buffer. +struct FuseMemBlock { + uint32_t opcode; + size_t offset; + size_t len; +}; + +// A wrapper of a simple serial buffer that can be used with read(2) and +// write(2). Contains a cursor to indicate accessing. This class is not thread- +// safe and can only be used in single-thread version. +class FuseMemBuffer { + public: + FuseMemBuffer() : cursor_(0) { + // To read from /dev/fuse, a buffer needs at least FUSE_MIN_READ_BUFFER + // bytes to avoid EINVAL. FuseMemBuffer holds memory that can accommodate + // a sequence of FUSE request/response, so it is initiated with double + // minimal requirement. + mem_.resize(FUSE_MIN_READ_BUFFER * 2); + } + + // Returns whether there is no memory block. + bool Empty() { return blocks_.empty(); } + + // Returns if there is no more remaining memory blocks. + bool End() { return cursor_ == blocks_.size(); } + + // Returns how many bytes that have been received. + size_t UsedBytes() { + return Empty() ? 0 : blocks_.back().offset + blocks_.back().len; + } + + // Returns the available bytes remains in the serial buffer. + size_t AvailBytes() { return mem_.size() - UsedBytes(); } + + // Appends a memory block information that starts at the tail of the serial + // buffer. /dev/fuse requires at least FUSE_MIN_READ_BUFFER bytes to read, or + // it will issue EINVAL. If it is not enough, just double the buffer length. + void AddMemBlock(uint32_t opcode, void* data, size_t len) { + if (AvailBytes() < FUSE_MIN_READ_BUFFER) { + mem_.resize(mem_.size() << 1); + } + size_t offset = UsedBytes(); + memcpy(mem_.data() + offset, data, len); + blocks_.push_back(FuseMemBlock{opcode, offset, len}); + } + + // Returns the memory address at a specific offset. Used with read(2) or + // write(2). + char* DataAtOffset(size_t offset) { return mem_.data() + offset; } + + // Returns current memory block pointed by the cursor and increase by 1. + FuseMemBlock Next() { + if (End()) { + std::cerr << "Buffer is already exhausted." << std::endl; + return FuseMemBlock{}; + } + return blocks_[cursor_++]; + } + + // Returns the number of the blocks that has not been requested. + size_t RemainingBlocks() { return blocks_.size() - cursor_; } + + private: + size_t cursor_; + std::vector<FuseMemBlock> blocks_; + std::vector<char> mem_; +}; + +// FuseTest base class is useful in FUSE integration test. Inherit this class +// to automatically set up a fake FUSE server and use the member functions +// to manipulate with it. Refer to test/fuse/README.md for detailed explanation. +class FuseTest : public ::testing::Test { + public: + // nodeid_ is the ID of a fake inode. We starts from 2 since 1 is occupied by + // the mount point. + FuseTest() : nodeid_(2) {} + void SetUp() override; + void TearDown() override; + + // Called by the testing thread to set up a fake response for an expected + // opcode via socket. This can be used multiple times to define a sequence of + // expected FUSE reactions. + void SetServerResponse(uint32_t opcode, std::vector<struct iovec>& iovecs); + + // Called by the testing thread to install a fake path under the mount point. + // e.g. a file under /mnt/dir/file and moint point is /mnt, then it will look + // up "dir/file" in this case. + // + // It sets a fixed response to the FUSE_LOOKUP requests issued with this + // path, pretending there is an inode and avoid ENOENT when testing. If mode + // is not given, it creates a regular file with mode 0600. + void SetServerInodeLookup(const std::string& path, + mode_t mode = S_IFREG | S_IRUSR | S_IWUSR, + uint64_t size = 512); + + // Called by the testing thread to ask the FUSE server for its next received + // FUSE request. Be sure to use the corresponding struct of iovec to receive + // data from server. + void GetServerActualRequest(std::vector<struct iovec>& iovecs); + + // Called by the testing thread to query the number of unconsumed requests in + // the requests_ serial buffer of the FUSE server. TearDown() ensures all + // FUSE requests received by the FUSE server were consumed by the testing + // thread. + uint32_t GetServerNumUnconsumedRequests(); + + // Called by the testing thread to query the number of unsent responses in + // the responses_ serial buffer of the FUSE server. TearDown() ensures all + // preset FUSE responses were sent out by the FUSE server. + uint32_t GetServerNumUnsentResponses(); + + // Called by the testing thread to ask the FUSE server for its total received + // bytes from /dev/fuse. + uint32_t GetServerTotalReceivedBytes(); + + // Called by the testing thread to ask the FUSE server to skip stored + // request data. + void SkipServerActualRequest(); + + protected: + TempPath mount_point_; + + // Opens /dev/fuse and inherit the file descriptor for the FUSE server. + void MountFuse(const char* mountOpts = kMountOpts); + + // Creates a socketpair for communication and forks FUSE server. + void SetUpFuseServer( + const struct fuse_init_out* payload = &kDefaultFUSEInitOutPayload); + + // Unmounts the mountpoint of the FUSE server. + void UnmountFuse(); + + private: + // Sends a FuseTestCmd and gets a uint32_t data from the FUSE server. + inline uint32_t GetServerData(uint32_t cmd); + + // Waits for FUSE server to complete its processing. Complains if the FUSE + // server responds any failure during tests. + void WaitServerComplete(); + + // The FUSE server stays here and waits next command or FUSE request until it + // is terminated. + void ServerFuseLoop(); + + // Used by the FUSE server to tell testing thread if it is OK to proceed next + // command. Will be issued after processing each FuseTestCmd. + void ServerCompleteWith(bool success); + + // Consumes the first FUSE request when mounting FUSE. Replies with a + // response with empty payload. + PosixError ServerConsumeFuseInit(const struct fuse_init_out* payload); + + // A command switch that dispatch different FuseTestCmd to its handler. + void ServerHandleCommand(); + + // The FUSE server side's corresponding code of `SetServerResponse()`. + // Handles `kSetResponse` command. Saves the fake response into its output + // memory queue. + void ServerReceiveResponse(); + + // The FUSE server side's corresponding code of `SetServerInodeLookup()`. + // Handles `kSetInodeLookup` command. Receives an expected file mode and + // file path under the mount point. + void ServerReceiveInodeLookup(); + + // The FUSE server side's corresponding code of `GetServerActualRequest()`. + // Handles `kGetRequest` command. Sends the next received request pointed by + // the cursor. + void ServerSendReceivedRequest(); + + // Sends a uint32_t data via socket. + inline void ServerSendData(uint32_t data); + + // The FUSE server side's corresponding code of `SkipServerActualRequest()`. + // Handles `kSkipRequest` command. Skip the request pointed by current cursor. + void ServerSkipReceivedRequest(); + + // Handles FUSE request sent to /dev/fuse by its saved responses. + void ServerProcessFuseRequest(); + + // Responds to FUSE request with a saved data. + void ServerRespondFuseSuccess(FuseMemBuffer& mem_buf, + const FuseMemBlock& block, uint64_t unique); + + // Responds an error header to /dev/fuse when bad thing happens. + void ServerRespondFuseError(uint64_t unique); + + int dev_fd_; + int sock_[2]; + + uint64_t nodeid_; + std::unordered_map<std::string, FuseMemBlock> lookup_map_; + + FuseMemBuffer requests_; + FuseMemBuffer responses_; + FuseMemBuffer lookups_; +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_FUSE_FUSE_BASE_H_ diff --git a/test/fuse/linux/fuse_fd_util.cc b/test/fuse/linux/fuse_fd_util.cc new file mode 100644 index 000000000..30d1157bb --- /dev/null +++ b/test/fuse/linux/fuse_fd_util.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/fuse/linux/fuse_fd_util.h" + +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/uio.h> + +#include <string> +#include <vector> + +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/fuse_util.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<FileDescriptor> FuseFdTest::OpenPath(const std::string &path, + uint32_t flags, uint64_t fh) { + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload = { + .fh = fh, + .open_flags = flags, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + + auto res = Open(path.c_str(), flags); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; +} + +Cleanup FuseFdTest::CloseFD(FileDescriptor &fd) { + return Cleanup([&] { + close(fd.release()); + SkipServerActualRequest(); + }); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_fd_util.h b/test/fuse/linux/fuse_fd_util.h new file mode 100644 index 000000000..066185c94 --- /dev/null +++ b/test/fuse/linux/fuse_fd_util.h @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ +#define GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ + +#include <fcntl.h> +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> + +#include "test/fuse/linux/fuse_base.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +class FuseFdTest : public FuseTest { + public: + // Sets the FUSE server to respond to a FUSE_OPEN with corresponding flags and + // fh. Then does a real file system open on the absolute path to get an fd. + PosixErrorOr<FileDescriptor> OpenPath(const std::string &path, + uint32_t flags = O_RDONLY, + uint64_t fh = 1); + + // Returns a cleanup object that closes the fd when it is destroyed. After + // the close is done, tells the FUSE server to skip this FUSE_RELEASE. + Cleanup CloseFD(FileDescriptor &fd); +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ diff --git a/test/fuse/linux/mkdir_test.cc b/test/fuse/linux/mkdir_test.cc new file mode 100644 index 000000000..9647cb93f --- /dev/null +++ b/test/fuse/linux/mkdir_test.cc @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MkdirTest : public FuseTest { + protected: + const std::string test_dir_ = "test_dir"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(MkdirTest, CreateDir) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_); + const mode_t new_umask = 0077; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKDIR, iov_out); + TempUmask mask(new_umask); + ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_mkdir_in in_payload; + std::vector<char> actual_dir(test_dir_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_dir); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + sizeof(in_payload) + test_dir_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_MKDIR); + EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask); + EXPECT_EQ(in_payload.umask, new_umask); + EXPECT_EQ(std::string(actual_dir.data()), test_dir_); +} + +TEST_F(MkdirTest, FileTypeError) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKDIR, iov_out); + ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/mknod_test.cc b/test/fuse/linux/mknod_test.cc new file mode 100644 index 000000000..74c74d76b --- /dev/null +++ b/test/fuse/linux/mknod_test.cc @@ -0,0 +1,107 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MknodTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(MknodTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + const mode_t new_umask = 0077; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + TempUmask mask(new_umask); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_mknod_in in_payload; + std::vector<char> actual_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + sizeof(in_payload) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_MKNOD); + EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask); + EXPECT_EQ(in_payload.umask, new_umask); + EXPECT_EQ(in_payload.rdev, 0); + EXPECT_EQ(std::string(actual_file.data()), test_file_); +} + +TEST_F(MknodTest, FileTypeError) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + // server return directory instead of regular file should cause an error. + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +TEST_F(MknodTest, NodeIDError) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = + DefaultEntryOut(S_IFREG | perms_, FUSE_ROOT_ID); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/open_test.cc b/test/fuse/linux/open_test.cc new file mode 100644 index 000000000..4b0c4a805 --- /dev/null +++ b/test/fuse/linux/open_test.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class OpenTest : public FuseTest { + // OpenTest doesn't care the release request when close a fd, + // so doesn't check leftover requests when tearing down. + void TearDown() { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t regular_file_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + + struct fuse_open_out out_payload_ = { + .fh = 1, + .open_flags = O_RDWR, + }; +}; + +TEST_F(OpenTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, regular_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPEN, iov_out); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + + struct fuse_in_header in_header; + struct fuse_open_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_OPEN); + EXPECT_EQ(in_payload.flags, O_RDWR); + EXPECT_THAT(fcntl(fd.get(), F_GETFL), SyscallSucceedsWithValue(O_RDWR)); +} + +TEST_F(OpenTest, SetNoOpen) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, regular_file_); + + // ENOSYS indicates open is not implemented. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + .error = -ENOSYS, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPEN, iov_out); + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + SkipServerActualRequest(); + + // check open doesn't send new request. + uint32_t recieved_before = GetServerTotalReceivedBytes(); + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before); +} + +TEST_F(OpenTest, OpenFail) { + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + .error = -ENOENT, + }; + + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPENDIR, iov_out); + ASSERT_THAT(open(mount_point_.path().c_str(), O_RDWR), + SyscallFailsWithErrno(ENOENT)); + + struct fuse_in_header in_header; + struct fuse_open_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_OPENDIR); + EXPECT_EQ(in_payload.flags, O_RDWR); +} + +TEST_F(OpenTest, DirectoryFlagOnRegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + SetServerInodeLookup(test_file_, regular_file_); + ASSERT_THAT(open(test_file_path.c_str(), O_RDWR | O_DIRECTORY), + SyscallFailsWithErrno(ENOTDIR)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/read_test.cc b/test/fuse/linux/read_test.cc new file mode 100644 index 000000000..88fc299d8 --- /dev/null +++ b/test/fuse/linux/read_test.cc @@ -0,0 +1,390 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReadTest : public FuseTest { + void SetUp() override { + FuseTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + // TearDown overrides the parent's function + // to skip checking the unconsumed release request at the end. + void TearDown() override { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + const uint64_t test_fh_ = 1; + const uint32_t open_flag_ = O_RDWR; + + std::string test_file_path_; + + PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path, + uint64_t size = 512) { + SetServerInodeLookup(test_file_, test_file_mode_, size); + + struct fuse_out_header out_header_open = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload_open = { + .fh = test_fh_, + .open_flags = open_flag_, + }; + auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open); + SetServerResponse(FUSE_OPEN, iov_out_open); + + auto res = Open(path.c_str(), open_flag_); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; + } +}; + +class ReadTestSmallMaxRead : public ReadTest { + void SetUp() override { + MountFuse(mountOpts); + SetUpFuseServer(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + protected: + constexpr static char mountOpts[] = + "rootmode=755,user_id=0,group_id=0,max_read=4096"; + // 4096 is hard-coded as the max_read in mount options. + const int size_fragment = 4096; +}; + +TEST_F(ReadTest, ReadWhole) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the read. + const int n_read = 5; + std::vector<char> data(n_read); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + EXPECT_EQ(buf, data); +} + +TEST_F(ReadTest, ReadPartial) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the read. + const int n_data = 10; + std::vector<char> data(n_data); + RandomizeBuffer(data.data(), data.size()); + // Note: due to read ahead, current read implementation will treat any + // response that is longer than requested as correct (i.e. not reach the EOF). + // Therefore, the test below should make sure the size to read does not exceed + // n_data. + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + std::vector<char> buf(n_data); + + // Read 1 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 1), SyscallSucceedsWithValue(1)); + + // Check the 1-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + + // Read 3 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 3), SyscallSucceedsWithValue(3)); + + // Check the 3-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_payload_read.offset, 1); + + // Read 5 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 5), SyscallSucceedsWithValue(5)); + + // Check the 5-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_payload_read.offset, 4); +} + +TEST_F(ReadTest, PRead) { + const int file_size = 512; + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size)); + + // Prepare for the read. + const int n_read = 5; + std::vector<char> data(n_read); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read some bytes. + std::vector<char> buf(n_read); + const int offset_read = file_size >> 1; + EXPECT_THAT(pread(fd.get(), buf.data(), n_read, offset_read), + SyscallSucceedsWithValue(n_read)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, offset_read); + EXPECT_EQ(buf, data); +} + +TEST_F(ReadTest, ReadZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Issue the read. + std::vector<char> buf; + EXPECT_THAT(read(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(ReadTest, ReadShort) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the short read. + const int n_read = 5; + std::vector<char> data(n_read >> 1); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(data.size())); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + std::vector<char> short_buf(buf.begin(), buf.begin() + data.size()); + EXPECT_EQ(short_buf, data); +} + +TEST_F(ReadTest, ReadShortEOF) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the short read. + struct fuse_out_header out_header_read = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + const int n_read = 10; + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), SyscallSucceedsWithValue(0)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); +} + +TEST_F(ReadTestSmallMaxRead, ReadSmallMaxRead) { + const int n_fragment = 10; + const int n_read = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read)); + + // Prepare for the read. + std::vector<char> data(size_fragment); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + + for (int i = 0; i < n_fragment; ++i) { + SetServerResponse(FUSE_READ, iov_out_read); + } + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read)); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check each read segment. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, i * size_fragment); + EXPECT_EQ(in_payload_read.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + EXPECT_EQ(std::vector<char>(it, it + size_fragment), data); + } +} + +TEST_F(ReadTestSmallMaxRead, ReadSmallMaxReadShort) { + const int n_fragment = 10; + const int n_read = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read)); + + // Prepare for the read. + std::vector<char> data(size_fragment); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + + for (int i = 0; i < n_fragment - 1; ++i) { + SetServerResponse(FUSE_READ, iov_out_read); + } + + // The last fragment is a short read. + std::vector<char> half_data(data.begin(), data.begin() + (data.size() >> 1)); + struct fuse_out_header out_header_read_short = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header) + + half_data.size()), + }; + auto iov_out_read_short = + FuseGenerateIovecs(out_header_read_short, half_data); + SetServerResponse(FUSE_READ, iov_out_read_short); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read - (data.size() >> 1))); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check each read segment. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, i * size_fragment); + EXPECT_EQ(in_payload_read.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + if (i != n_fragment - 1) { + EXPECT_EQ(std::vector<char>(it, it + data.size()), data); + } else { + EXPECT_EQ(std::vector<char>(it, it + half_data.size()), half_data); + } + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/readdir_test.cc b/test/fuse/linux/readdir_test.cc new file mode 100644 index 000000000..2afb4b062 --- /dev/null +++ b/test/fuse/linux/readdir_test.cc @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <dirent.h> +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <linux/unistd.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +#define FUSE_NAME_OFFSET offsetof(struct fuse_dirent, name) +#define FUSE_DIRENT_ALIGN(x) \ + (((x) + sizeof(uint64_t) - 1) & ~(sizeof(uint64_t) - 1)) +#define FUSE_DIRENT_SIZE(d) FUSE_DIRENT_ALIGN(FUSE_NAME_OFFSET + (d)->namelen) + +namespace gvisor { +namespace testing { + +namespace { + +class ReaddirTest : public FuseTest { + public: + void fill_fuse_dirent(char *buf, const char *name, uint64_t ino) { + size_t namelen = strlen(name); + size_t entlen = FUSE_NAME_OFFSET + namelen; + size_t entlen_padded = FUSE_DIRENT_ALIGN(entlen); + struct fuse_dirent *dirent; + + dirent = reinterpret_cast<struct fuse_dirent *>(buf); + dirent->ino = ino; + dirent->namelen = namelen; + memcpy(dirent->name, name, namelen); + memset(dirent->name + namelen, 0, entlen_padded - entlen); + } + + protected: + const std::string test_dir_name_ = "test_dir"; +}; + +TEST_F(ReaddirTest, SingleEntry) { + const std::string test_dir_path = + JoinPath(mount_point_.path().c_str(), test_dir_name_); + + const uint64_t ino_dir = 1024; + // We need to make sure the test dir is a directory that can be found. + mode_t expected_mode = + S_IFDIR | S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; + struct fuse_attr dir_attr = { + .ino = ino_dir, + .size = 512, + .blocks = 4, + .mode = expected_mode, + .blksize = 4096, + }; + + // We need to make sure the test dir is a directory that can be found. + struct fuse_out_header lookup_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out lookup_payload = { + .nodeid = 1, + .entry_valid = true, + .attr_valid = true, + .attr = dir_attr, + }; + + struct fuse_out_header open_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out open_payload = { + .fh = 1, + }; + auto iov_out = FuseGenerateIovecs(lookup_header, lookup_payload); + SetServerResponse(FUSE_LOOKUP, iov_out); + + iov_out = FuseGenerateIovecs(open_header, open_payload); + SetServerResponse(FUSE_OPENDIR, iov_out); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_dir_path.c_str(), O_RDONLY)); + + // The open command makes two syscalls. Lookup the dir file and open. + // We don't need to inspect those headers in this test. + SkipServerActualRequest(); // LOOKUP. + SkipServerActualRequest(); // OPENDIR. + + // Readdir test code. + std::string dot = "."; + std::string dot_dot = ".."; + std::string test_file = "testFile"; + + // Figure out how many dirents to send over and allocate them appropriately. + // Each dirent has a dynamic name and a static metadata part. The dirent size + // is aligned to being a multiple of 8. + size_t dot_file_dirent_size = + FUSE_DIRENT_ALIGN(dot.length() + FUSE_NAME_OFFSET); + size_t dot_dot_file_dirent_size = + FUSE_DIRENT_ALIGN(dot_dot.length() + FUSE_NAME_OFFSET); + size_t test_file_dirent_size = + FUSE_DIRENT_ALIGN(test_file.length() + FUSE_NAME_OFFSET); + + // Create an appropriately sized payload. + size_t readdir_payload_size = + test_file_dirent_size + dot_file_dirent_size + dot_dot_file_dirent_size; + std::vector<char> readdir_payload_vec(readdir_payload_size); + char *readdir_payload = readdir_payload_vec.data(); + + // Use fake ino for other directories. + fill_fuse_dirent(readdir_payload, dot.c_str(), ino_dir - 2); + fill_fuse_dirent(readdir_payload + dot_file_dirent_size, dot_dot.c_str(), + ino_dir - 1); + fill_fuse_dirent( + readdir_payload + dot_file_dirent_size + dot_dot_file_dirent_size, + test_file.c_str(), ino_dir); + + struct fuse_out_header readdir_header = { + .len = uint32_t(sizeof(struct fuse_out_header) + readdir_payload_size), + }; + struct fuse_out_header readdir_header_break = { + .len = uint32_t(sizeof(struct fuse_out_header)), + }; + + iov_out = FuseGenerateIovecs(readdir_header, readdir_payload_vec); + SetServerResponse(FUSE_READDIR, iov_out); + + iov_out = FuseGenerateIovecs(readdir_header_break); + SetServerResponse(FUSE_READDIR, iov_out); + + std::vector<char> buf(4090, 0); + int nread, off = 0, i = 0; + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceeds()); + for (; off < nread;) { + struct dirent64 *ent = (struct dirent64 *)(buf.data() + off); + off += ent->d_reclen; + switch (i++) { + case 0: + EXPECT_EQ(std::string(ent->d_name), dot); + break; + case 1: + EXPECT_EQ(std::string(ent->d_name), dot_dot); + break; + case 2: + EXPECT_EQ(std::string(ent->d_name), test_file); + break; + } + } + + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(0)); + + SkipServerActualRequest(); // READDIR. + SkipServerActualRequest(); // READDIR with no data. + + // Clean up. + fd.reset(-1); + + struct fuse_in_header in_header; + struct fuse_release_in in_payload; + + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_RELEASEDIR); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/readlink_test.cc b/test/fuse/linux/readlink_test.cc new file mode 100644 index 000000000..2cba8fc23 --- /dev/null +++ b/test/fuse/linux/readlink_test.cc @@ -0,0 +1,85 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReadlinkTest : public FuseTest { + protected: + const std::string test_file_ = "test_file_"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(ReadlinkTest, ReadSymLink) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFLNK | perms_); + + struct fuse_out_header out_header = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)) + + static_cast<uint32_t>(test_file_.length()) + 1, + }; + std::string link = test_file_; + auto iov_out = FuseGenerateIovecs(out_header, link); + SetServerResponse(FUSE_READLINK, iov_out); + const std::string actual_link = + ASSERT_NO_ERRNO_AND_VALUE(ReadLink(symlink_path)); + + struct fuse_in_header in_header; + auto iov_in = FuseGenerateIovecs(in_header); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header)); + EXPECT_EQ(in_header.opcode, FUSE_READLINK); + EXPECT_EQ(0, memcmp(actual_link.c_str(), link.data(), link.size())); + + // next readlink should have link cached, so shouldn't have new request to + // server. + uint32_t recieved_before = GetServerTotalReceivedBytes(); + ASSERT_NO_ERRNO(ReadLink(symlink_path)); + EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before); +} + +TEST_F(ReadlinkTest, NotSymlink) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | perms_); + + std::vector<char> buf(PATH_MAX + 1); + ASSERT_THAT(readlink(test_file_path.c_str(), buf.data(), PATH_MAX), + SyscallFailsWithErrno(EINVAL)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/release_test.cc b/test/fuse/linux/release_test.cc new file mode 100644 index 000000000..b5adb0870 --- /dev/null +++ b/test/fuse/linux/release_test.cc @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReleaseTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; +}; + +TEST_F(ReleaseTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload = { + .fh = 1, + .open_flags = O_RDWR, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path, O_RDWR)); + SkipServerActualRequest(); + ASSERT_THAT(close(fd.release()), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_release_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_RELEASE); + EXPECT_EQ(in_payload.flags, O_RDWR); + EXPECT_EQ(in_payload.fh, 1); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/rmdir_test.cc b/test/fuse/linux/rmdir_test.cc new file mode 100644 index 000000000..e3200e446 --- /dev/null +++ b/test/fuse/linux/rmdir_test.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class RmDirTest : public FuseTest { + protected: + const std::string test_dir_name_ = "test_dir"; + const std::string test_subdir_ = "test_subdir"; + const mode_t test_dir_mode_ = S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(RmDirTest, NormalRmDir) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_name_); + + SetServerInodeLookup(test_dir_name_, test_dir_mode_); + + // RmDir code. + struct fuse_out_header rmdir_header = { + .len = sizeof(struct fuse_out_header), + }; + + auto iov_out = FuseGenerateIovecs(rmdir_header); + SetServerResponse(FUSE_RMDIR, iov_out); + + ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_dirname(test_dir_name_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, actual_dirname); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_RMDIR); + EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_); +} + +TEST_F(RmDirTest, NormalRmDirSubdir) { + SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO); + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_subdir_, test_dir_name_); + SetServerInodeLookup(test_dir_name_, test_dir_mode_); + + // RmDir code. + struct fuse_out_header rmdir_header = { + .len = sizeof(struct fuse_out_header), + }; + + auto iov_out = FuseGenerateIovecs(rmdir_header); + SetServerResponse(FUSE_RMDIR, iov_out); + + ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_dirname(test_dir_name_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, actual_dirname); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_RMDIR); + EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/setstat_test.cc b/test/fuse/linux/setstat_test.cc new file mode 100644 index 000000000..68301c775 --- /dev/null +++ b/test/fuse/linux/setstat_test.cc @@ -0,0 +1,338 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> +#include <utime.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_fd_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class SetStatTest : public FuseFdTest { + public: + void SetUp() override { + FuseFdTest::SetUp(); + test_dir_path_ = JoinPath(mount_point_.path(), test_dir_); + test_file_path_ = JoinPath(mount_point_.path(), test_file_); + } + + protected: + const uint64_t fh = 23; + const std::string test_dir_ = "testdir"; + const std::string test_file_ = "testfile"; + const mode_t test_dir_mode_ = S_IFDIR | S_IRUSR | S_IWUSR | S_IXUSR; + const mode_t test_file_mode_ = S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR; + + std::string test_dir_path_; + std::string test_file_path_; +}; + +TEST_F(SetStatTest, ChmodDir) { + // Set up fixture. + SetServerInodeLookup(test_dir_, test_dir_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + mode_t set_mode = S_IRGRP | S_IWGRP | S_IXGRP; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(set_mode, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(chmod(test_dir_path_.c_str(), set_mode), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_MODE); + EXPECT_EQ(in_payload.mode, S_IFDIR | set_mode); +} + +TEST_F(SetStatTest, ChownDir) { + // Set up fixture. + SetServerInodeLookup(test_dir_, test_dir_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_dir_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(chown(test_dir_path_.c_str(), 1025, 1025), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID); + EXPECT_EQ(in_payload.uid, 1025); + EXPECT_EQ(in_payload.gid, 1025); +} + +TEST_F(SetStatTest, TruncateFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(truncate(test_file_path_.c_str(), 321), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_SIZE); + EXPECT_EQ(in_payload.size, 321); +} + +TEST_F(SetStatTest, UtimeFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + time_t expected_atime = 1597159766, expected_mtime = 1597159765; + struct utimbuf times = { + .actime = expected_atime, + .modtime = expected_mtime, + }; + EXPECT_THAT(utime(test_file_path_.c_str(), ×), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME); + EXPECT_EQ(in_payload.atime, expected_atime); + EXPECT_EQ(in_payload.mtime, expected_mtime); +} + +TEST_F(SetStatTest, UtimesFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_file_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + struct timeval expected_times[2] = { + { + .tv_sec = 1597159766, + .tv_usec = 234945, + }, + { + .tv_sec = 1597159765, + .tv_usec = 232341, + }, + }; + EXPECT_THAT(utimes(test_file_path_.c_str(), expected_times), + SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME); + EXPECT_EQ(in_payload.atime, expected_times[0].tv_sec); + EXPECT_EQ(in_payload.atimensec, expected_times[0].tv_usec * 1000); + EXPECT_EQ(in_payload.mtime, expected_times[1].tv_sec); + EXPECT_EQ(in_payload.mtimensec, expected_times[1].tv_usec * 1000); +} + +TEST_F(SetStatTest, FtruncateFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_file_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(ftruncate(fd.get(), 321), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_SIZE | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.size, 321); +} + +TEST_F(SetStatTest, FchmodFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + mode_t set_mode = S_IROTH | S_IWOTH | S_IXOTH; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(set_mode, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(fchmod(fd.get(), set_mode), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_MODE | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.mode, S_IFREG | set_mode); +} + +TEST_F(SetStatTest, FchownFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(fchown(fd.get(), 1025, 1025), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.uid, 1025); + EXPECT_EQ(in_payload.gid, 1025); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc new file mode 100644 index 000000000..6f032cac1 --- /dev/null +++ b/test/fuse/linux/stat_test.cc @@ -0,0 +1,219 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_fd_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class StatTest : public FuseFdTest { + public: + void SetUp() override { + FuseFdTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path(), test_file_); + } + + protected: + bool StatsAreEqual(struct stat expected, struct stat actual) { + // Device number will be dynamically allocated by kernel, we cannot know in + // advance. + actual.st_dev = expected.st_dev; + return memcmp(&expected, &actual, sizeof(struct stat)) == 0; + } + + const std::string test_file_ = "testfile"; + const mode_t expected_mode = S_IFREG | S_IRUSR | S_IWUSR; + const uint64_t fh = 23; + + std::string test_file_path_; +}; + +TEST_F(StatTest, StatNormal) { + // Set up fixture. + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 1); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallSucceeds()); + + // Check filesystem operation result. + struct stat expected_stat = { + .st_ino = attr.ino, + .st_nlink = attr.nlink, + .st_mode = expected_mode, + .st_uid = attr.uid, + .st_gid = attr.gid, + .st_rdev = attr.rdev, + .st_size = static_cast<off_t>(attr.size), + .st_blksize = attr.blksize, + .st_blocks = static_cast<blkcnt_t>(attr.blocks), + .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime), + .tv_nsec = attr.atimensec}, + .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime), + .tv_nsec = attr.mtimensec}, + .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime), + .tv_nsec = attr.ctimensec}, + }; + EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, StatNotFound) { + // Set up fixture. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), + SyscallFailsWithErrno(ENOENT)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, FstatNormal) { + // Set up fixture. + SetServerInodeLookup(test_file_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(fstat(fd.get(), &stat_buf), SyscallSucceeds()); + + // Check filesystem operation result. + struct stat expected_stat = { + .st_ino = attr.ino, + .st_nlink = attr.nlink, + .st_mode = expected_mode, + .st_uid = attr.uid, + .st_gid = attr.gid, + .st_rdev = attr.rdev, + .st_size = static_cast<off_t>(attr.size), + .st_blksize = attr.blksize, + .st_blocks = static_cast<blkcnt_t>(attr.blocks), + .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime), + .tv_nsec = attr.atimensec}, + .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime), + .tv_nsec = attr.mtimensec}, + .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime), + .tv_nsec = attr.ctimensec}, + }; + EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, StatByFileHandle) { + // Set up fixture. + SetServerInodeLookup(test_file_, expected_mode, 0); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2, 0); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + std::vector<char> buf(1); + // Since this is an empty file, it won't issue FUSE_READ. But a FUSE_GETATTR + // will be issued before read completes. + EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, FUSE_GETATTR_FH); + EXPECT_EQ(in_payload.fh, fh); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/symlink_test.cc b/test/fuse/linux/symlink_test.cc new file mode 100644 index 000000000..2c3a52987 --- /dev/null +++ b/test/fuse/linux/symlink_test.cc @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class SymlinkTest : public FuseTest { + protected: + const std::string target_file_ = "target_file_"; + const std::string symlink_ = "symlink_"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(SymlinkTest, CreateSymLink) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), symlink_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFLNK | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SYMLINK, iov_out); + ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()), + SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_target_file(target_file_.length() + 1); + std::vector<char> actual_symlink(symlink_.length() + 1); + auto iov_in = + FuseGenerateIovecs(in_header, actual_symlink, actual_target_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + symlink_.length() + target_file_.length() + 2); + EXPECT_EQ(in_header.opcode, FUSE_SYMLINK); + EXPECT_EQ(std::string(actual_target_file.data()), target_file_); + EXPECT_EQ(std::string(actual_symlink.data()), symlink_); +} + +TEST_F(SymlinkTest, FileTypeError) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), symlink_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SYMLINK, iov_out); + ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/unlink_test.cc b/test/fuse/linux/unlink_test.cc new file mode 100644 index 000000000..13efbf7c7 --- /dev/null +++ b/test/fuse/linux/unlink_test.cc @@ -0,0 +1,107 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class UnlinkTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; + const std::string test_subdir_ = "test_subdir"; +}; + +TEST_F(UnlinkTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds()); + struct fuse_in_header in_header; + std::vector<char> unlinked_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, unlinked_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_UNLINK); + EXPECT_EQ(std::string(unlinked_file.data()), test_file_); +} + +TEST_F(UnlinkTest, RegularFileSubDir) { + SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO); + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_subdir_, test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds()); + struct fuse_in_header in_header; + std::vector<char> unlinked_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, unlinked_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_UNLINK); + EXPECT_EQ(std::string(unlinked_file.data()), test_file_); +} + +TEST_F(UnlinkTest, NoFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallFailsWithErrno(ENOENT)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/write_test.cc b/test/fuse/linux/write_test.cc new file mode 100644 index 000000000..1a62beb96 --- /dev/null +++ b/test/fuse/linux/write_test.cc @@ -0,0 +1,303 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class WriteTest : public FuseTest { + void SetUp() override { + FuseTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + // TearDown overrides the parent's function + // to skip checking the unconsumed release request at the end. + void TearDown() override { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + const uint64_t test_fh_ = 1; + const uint32_t open_flag_ = O_RDWR; + + std::string test_file_path_; + + PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path, + uint64_t size = 512) { + SetServerInodeLookup(test_file_, test_file_mode_, size); + + struct fuse_out_header out_header_open = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload_open = { + .fh = test_fh_, + .open_flags = open_flag_, + }; + auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open); + SetServerResponse(FUSE_OPEN, iov_out_open); + + auto res = Open(path.c_str(), open_flag_); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; + } +}; + +class WriteTestSmallMaxWrite : public WriteTest { + void SetUp() override { + MountFuse(); + SetUpFuseServer(&fuse_init_payload); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + protected: + const static uint32_t max_write_ = 4096; + constexpr static struct fuse_init_out fuse_init_payload = { + .major = 7, + .max_write = max_write_, + }; + + const uint32_t size_fragment = max_write_; +}; + +TEST_F(WriteTest, WriteNormal) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_write, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_write)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteShort) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10, n_written = 5; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_written, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_written)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteShortZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = 0, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), SyscallFailsWithErrno(EIO)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Issue the write. + std::vector<char> buf(0); + EXPECT_THAT(write(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(WriteTest, PWrite) { + const int file_size = 512; + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_write, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + const int offset_write = file_size >> 1; + EXPECT_THAT(pwrite(fd.get(), buf.data(), n_write, offset_write), + SyscallSucceedsWithValue(n_write)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, offset_write); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTestSmallMaxWrite, WriteSmallMaxWrie) { + const int n_fragment = 10; + const int n_write = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_write)); + + // Prepare for the write. + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = size_fragment, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + + for (int i = 0; i < n_fragment; ++i) { + SetServerResponse(FUSE_WRITE, iov_out_write); + } + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_write)); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(size_fragment); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, i * size_fragment); + EXPECT_EQ(in_payload_write.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + EXPECT_EQ(std::vector<char>(it, it + size_fragment), payload_buf); + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/image/image_test.go b/test/image/image_test.go index 3e4321480..968e62f63 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -22,6 +22,7 @@ package image import ( + "context" "flag" "fmt" "io/ioutil" @@ -36,12 +37,20 @@ import ( "gvisor.dev/gvisor/pkg/test/testutil" ) +// defaultWait defines how long to wait for progress. +// +// See BUILD: This is at least a "large" test, so allow up to 1 minute for any +// given "wait" step. Note that all tests are run in parallel, which may cause +// individual slow-downs (but a huge speed-up in aggregate). +const defaultWait = time.Minute + func TestHelloWorld(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Run the basic container. - out, err := d.Run(dockerutil.RunOpts{ + out, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "echo", "Hello world!") if err != nil { @@ -54,8 +63,8 @@ func TestHelloWorld(t *testing.T) { } } -func runHTTPRequest(port int) error { - url := fmt.Sprintf("http://localhost:%d/not-found", port) +func runHTTPRequest(ip string, port int) error { + url := fmt.Sprintf("http://%s:%d/not-found", ip, port) resp, err := http.Get(url) if err != nil { return fmt.Errorf("error reaching http server: %v", err) @@ -64,7 +73,7 @@ func runHTTPRequest(port int) error { return fmt.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) } - url = fmt.Sprintf("http://localhost:%d/latin10k.txt", port) + url = fmt.Sprintf("http://%s:%d/latin10k.txt", ip, port) resp, err = http.Get(url) if err != nil { return fmt.Errorf("Error reaching http server: %v", err) @@ -86,13 +95,13 @@ func runHTTPRequest(port int) error { return nil } -func testHTTPServer(t *testing.T, port int) { +func testHTTPServer(t *testing.T, ip string, port int) { const requests = 10 ch := make(chan error, requests) for i := 0; i < requests; i++ { go func() { start := time.Now() - err := runHTTPRequest(port) + err := runHTTPRequest(ip, port) log.Printf("Response time %v: %v", time.Since(start).String(), err) ch <- err }() @@ -101,73 +110,78 @@ func testHTTPServer(t *testing.T, port int) { for i := 0; i < requests; i++ { err := <-ch if err != nil { - t.Errorf("testHTTPServer(%d) failed: %v", port, err) + t.Errorf("testHTTPServer(%s, %d) failed: %v", ip, port, err) } } } func TestHttpd(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. + port := 80 opts := dockerutil.RunOpts{ Image: "basic/httpd", - Ports: []int{80}, + Ports: []int{port}, } d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt") - if err := d.Spawn(opts); err != nil { + if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 80 is mapped to. - port, err := d.FindPort(80) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } - testHTTPServer(t, port) + testHTTPServer(t, ip.String(), port) } func TestNginx(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. + port := 80 opts := dockerutil.RunOpts{ Image: "basic/nginx", - Ports: []int{80}, + Ports: []int{port}, } d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt") - if err := d.Spawn(opts); err != nil { + if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 80 is mapped to. - port, err := d.FindPort(80) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } - testHTTPServer(t, port) + testHTTPServer(t, ip.String(), port) } func TestMysql(t *testing.T) { - server := dockerutil.MakeDocker(t) - defer server.CleanUp() + ctx := context.Background() + server := dockerutil.MakeContainer(ctx, t) + defer server.CleanUp(ctx) // Start the container. - if err := server.Spawn(dockerutil.RunOpts{ + if err := server.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/mysql", Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"}, }); err != nil { @@ -175,61 +189,58 @@ func TestMysql(t *testing.T) { } // Wait until it's up and running. - if _, err := server.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil { + if _, err := server.WaitForOutput(ctx, "port: 3306 MySQL Community Server", defaultWait); err != nil { t.Fatalf("WaitForOutput() timeout: %v", err) } // Generate the client and copy in the SQL payload. - client := dockerutil.MakeDocker(t) - defer client.CleanUp() + client := dockerutil.MakeContainer(ctx, t) + defer client.CleanUp(ctx) // Tell mysql client to connect to the server and execute the file in // verbose mode to verify the output. opts := dockerutil.RunOpts{ Image: "basic/mysql", - Links: []dockerutil.Link{ - { - Source: server, - Target: "mysql", - }, - }, + Links: []string{server.MakeLink("mysql")}, } client.CopyFiles(&opts, "/sql", "test/image/mysql.sql") - if _, err := client.Run(opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil { + if _, err := client.Run(ctx, opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil { t.Fatalf("docker run failed: %v", err) } // Ensure file executed to the end and shutdown mysql. - if _, err := server.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil { + if _, err := server.WaitForOutput(ctx, "mysqld: Shutdown complete", defaultWait); err != nil { t.Fatalf("WaitForOutput() timeout: %v", err) } } func TestTomcat(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the server. - if err := d.Spawn(dockerutil.RunOpts{ + port := 8080 + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/tomcat", - Ports: []int{8080}, + Ports: []int{port}, }); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) + url := fmt.Sprintf("http://%s:%d", ip.String(), port) resp, err := http.Get(url) if err != nil { t.Errorf("Error reaching http server: %v", err) @@ -240,32 +251,34 @@ func TestTomcat(t *testing.T) { } func TestRuby(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Execute the ruby workload. + port := 8080 opts := dockerutil.RunOpts{ Image: "basic/ruby", - Ports: []int{8080}, + Ports: []int{port}, } d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh") - if err := d.Spawn(opts, "/src/ruby.sh"); err != nil { + if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running, 'gem install' can take some time. - if err := testutil.WaitForHTTP(port, 1*time.Minute); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, time.Minute); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) + url := fmt.Sprintf("http://%s:%d", ip.String(), port) resp, err := http.Get(url) if err != nil { t.Errorf("error reaching http server: %v", err) @@ -283,20 +296,21 @@ func TestRuby(t *testing.T) { } func TestStdio(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) wantStdout := "hello stdout" wantStderr := "bonjour stderr" cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "/bin/sh", "-c", cmd); err != nil { t.Fatalf("docker run failed: %v", err) } for _, want := range []string{wantStdout, wantStderr} { - if _, err := d.WaitForOutput(want, 5*time.Second); err != nil { + if _, err := d.WaitForOutput(ctx, want, defaultWait); err != nil { t.Fatalf("docker didn't get output %q : %v", want, err) } } diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 3e29ca90d..66453772a 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -9,6 +9,7 @@ go_library( "filter_input.go", "filter_output.go", "iptables.go", + "iptables_unsafe.go", "iptables_util.go", "nat.go", ], @@ -20,6 +21,7 @@ go_library( go_test( name = "iptables_test", + size = "large", srcs = [ "iptables_test.go", ], diff --git a/test/iptables/README.md b/test/iptables/README.md index b9f44bd40..28ab195ca 100644 --- a/test/iptables/README.md +++ b/test/iptables/README.md @@ -1,6 +1,6 @@ # iptables Tests -iptables tests are run via `scripts/iptables_test.sh`. +iptables tests are run via `make iptables-tests`. iptables requires raw socket support, so you must add the `--net-raw=true` flag to `/etc/docker/daemon.json` in order to use it. diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 872021358..b45d448b8 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -15,6 +15,7 @@ package iptables import ( + "context" "errors" "fmt" "net" @@ -25,7 +26,6 @@ const ( dropPort = 2401 acceptPort = 2402 sendloopDuration = 2 * time.Second - network = "udp4" chainName = "foochain" ) @@ -54,7 +54,7 @@ func init() { } // FilterInputDropUDP tests that we can drop UDP traffic. -type FilterInputDropUDP struct{} +type FilterInputDropUDP struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropUDP) Name() string { @@ -62,15 +62,17 @@ func (FilterInputDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { +func (FilterInputDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -80,12 +82,12 @@ func (FilterInputDropUDP) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic. -type FilterInputDropOnlyUDP struct{} +type FilterInputDropOnlyUDP struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropOnlyUDP) Name() string { @@ -93,13 +95,13 @@ func (FilterInputDropOnlyUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { +func (FilterInputDropOnlyUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } // Listen for a TCP connection, which should be allowed. - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return fmt.Errorf("failed to establish a connection %v", err) } @@ -107,14 +109,14 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error { +func (FilterInputDropOnlyUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Try to establish a TCP connection with the container, which should // succeed. - return connectTCP(ip, acceptPort, sendloopDuration) + return connectTCP(ctx, ip, acceptPort) } // FilterInputDropUDPPort tests that we can drop UDP traffic by port. -type FilterInputDropUDPPort struct{} +type FilterInputDropUDPPort struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropUDPPort) Name() string { @@ -122,15 +124,17 @@ func (FilterInputDropUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterInputDropUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -140,13 +144,13 @@ func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port // doesn't drop packets on other ports. -type FilterInputDropDifferentUDPPort struct{} +type FilterInputDropDifferentUDPPort struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropDifferentUDPPort) Name() string { @@ -154,13 +158,13 @@ func (FilterInputDropDifferentUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterInputDropDifferentUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for UDP packets on another port. - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -168,12 +172,12 @@ func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDropDifferentUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPDestPort struct{} +type FilterInputDropTCPDestPort struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropTCPDestPort) Name() string { @@ -181,33 +185,36 @@ func (FilterInputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterInputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error { +func (FilterInputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. - for start := time.Now(); time.Since(start) < sendloopDuration; { - if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil { - return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) - } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) } - return nil } // FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPSrcPort struct{} +type FilterInputDropTCPSrcPort struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropTCPSrcPort) Name() string { @@ -215,34 +222,37 @@ func (FilterInputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error { +func (FilterInputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Drop anything from an ephemeral port. - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error { +func (FilterInputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. - for start := time.Now(); time.Since(start) < sendloopDuration; { - if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil { - return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) - } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) } - return nil } // FilterInputDropAll tests that we can drop all traffic to the INPUT chain. -type FilterInputDropAll struct{} +type FilterInputDropAll struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropAll) Name() string { @@ -250,15 +260,17 @@ func (FilterInputDropAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropAll) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil { +func (FilterInputDropAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-j", "DROP"); err != nil { return err } // Listen for all packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets should have been dropped, but got a packet") - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -268,15 +280,15 @@ func (FilterInputDropAll) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropAll) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputMultiUDPRules verifies that multiple UDP rules are applied // correctly. This has the added benefit of testing whether we're serializing // rules correctly -- if we do it incorrectly, the iptables tool will // misunderstand and save the wrong tables. -type FilterInputMultiUDPRules struct{} +type FilterInputMultiUDPRules struct{ baseCase } // Name implements TestCase.Name. func (FilterInputMultiUDPRules) Name() string { @@ -284,24 +296,24 @@ func (FilterInputMultiUDPRules) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { +func (FilterInputMultiUDPRules) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, {"-L"}, } - return filterTableRules(rules) + return filterTableRules(ipv6, rules) } // LocalAction implements TestCase.LocalAction. -func (FilterInputMultiUDPRules) LocalAction(ip net.IP) error { +func (FilterInputMultiUDPRules) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be // specified. -type FilterInputRequireProtocolUDP struct{} +type FilterInputRequireProtocolUDP struct{ baseCase } // Name implements TestCase.Name. func (FilterInputRequireProtocolUDP) Name() string { @@ -309,20 +321,20 @@ func (FilterInputRequireProtocolUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { +func (FilterInputRequireProtocolUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { return errors.New("expected iptables to fail with out \"-p udp\", but succeeded") } return nil } -func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error { +func (FilterInputRequireProtocolUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputCreateUserChain tests chain creation. -type FilterInputCreateUserChain struct{} +type FilterInputCreateUserChain struct{ baseCase } // Name implements TestCase.Name. func (FilterInputCreateUserChain) Name() string { @@ -330,24 +342,24 @@ func (FilterInputCreateUserChain) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { +func (FilterInputCreateUserChain) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ // Create a chain. {"-N", chainName}, // Add a simple rule to the chain. {"-A", chainName, "-j", "DROP"}, } - return filterTableRules(rules) + return filterTableRules(ipv6, rules) } // LocalAction implements TestCase.LocalAction. -func (FilterInputCreateUserChain) LocalAction(ip net.IP) error { +func (FilterInputCreateUserChain) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputDefaultPolicyAccept tests the default ACCEPT policy. -type FilterInputDefaultPolicyAccept struct{} +type FilterInputDefaultPolicyAccept struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDefaultPolicyAccept) Name() string { @@ -355,21 +367,21 @@ func (FilterInputDefaultPolicyAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error { +func (FilterInputDefaultPolicyAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Set the default policy to accept, then receive a packet. - if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + if err := filterTable(ipv6, "-P", "INPUT", "ACCEPT"); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDefaultPolicyDrop tests the default DROP policy. -type FilterInputDefaultPolicyDrop struct{} +type FilterInputDefaultPolicyDrop struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDefaultPolicyDrop) Name() string { @@ -377,15 +389,17 @@ func (FilterInputDefaultPolicyDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { - if err := filterTable("-P", "INPUT", "DROP"); err != nil { +func (FilterInputDefaultPolicyDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-P", "INPUT", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -395,13 +409,13 @@ func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes // the underflow rule (i.e. default policy) to be executed. -type FilterInputReturnUnderflow struct{} +type FilterInputReturnUnderflow struct{ containerCase } // Name implements TestCase.Name. func (FilterInputReturnUnderflow) Name() string { @@ -409,7 +423,7 @@ func (FilterInputReturnUnderflow) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { +func (FilterInputReturnUnderflow) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Add a RETURN rule followed by an unconditional accept, and set the // default policy to DROP. rules := [][]string{ @@ -417,22 +431,22 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { {"-A", "INPUT", "-j", "DROP"}, {"-P", "INPUT", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // We should receive packets, as the RETURN rule will trigger the default // ACCEPT policy. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputReturnUnderflow) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputSerializeJump verifies that we can serialize jumps. -type FilterInputSerializeJump struct{} +type FilterInputSerializeJump struct{ baseCase } // Name implements TestCase.Name. func (FilterInputSerializeJump) Name() string { @@ -440,24 +454,24 @@ func (FilterInputSerializeJump) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSerializeJump) ContainerAction(ip net.IP) error { +func (FilterInputSerializeJump) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Write a JUMP rule, the serialize it with `-L`. rules := [][]string{ {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, {"-L"}, } - return filterTableRules(rules) + return filterTableRules(ipv6, rules) } // LocalAction implements TestCase.LocalAction. -func (FilterInputSerializeJump) LocalAction(ip net.IP) error { +func (FilterInputSerializeJump) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputJumpBasic jumps to a chain and executes a rule there. -type FilterInputJumpBasic struct{} +type FilterInputJumpBasic struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpBasic) Name() string { @@ -465,28 +479,28 @@ func (FilterInputJumpBasic) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { +func (FilterInputJumpBasic) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-P", "INPUT", "DROP"}, {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, {"-A", chainName, "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBasic) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpBasic) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputJumpReturn jumps, returns, and executes a rule. -type FilterInputJumpReturn struct{} +type FilterInputJumpReturn struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpReturn) Name() string { @@ -494,7 +508,7 @@ func (FilterInputJumpReturn) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { +func (FilterInputJumpReturn) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-P", "INPUT", "ACCEPT"}, @@ -502,21 +516,21 @@ func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { {"-A", chainName, "-j", "RETURN"}, {"-A", chainName, "-j", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturn) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpReturn) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. -type FilterInputJumpReturnDrop struct{} +type FilterInputJumpReturnDrop struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpReturnDrop) Name() string { @@ -524,21 +538,23 @@ func (FilterInputJumpReturnDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { +func (FilterInputJumpReturnDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, {"-A", "INPUT", "-j", "DROP"}, {"-A", chainName, "-j", "RETURN"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -548,12 +564,12 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputJumpReturnDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. -type FilterInputJumpBuiltin struct{} +type FilterInputJumpBuiltin struct{ baseCase } // Name implements TestCase.Name. func (FilterInputJumpBuiltin) Name() string { @@ -561,21 +577,21 @@ func (FilterInputJumpBuiltin) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil { +func (FilterInputJumpBuiltin) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-j", "OUTPUT"); err == nil { return fmt.Errorf("iptables should be unable to jump to a built-in chain") } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error { +func (FilterInputJumpBuiltin) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputJumpTwice jumps twice, then returns twice and executes a rule. -type FilterInputJumpTwice struct{} +type FilterInputJumpTwice struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpTwice) Name() string { @@ -583,7 +599,7 @@ func (FilterInputJumpTwice) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { +func (FilterInputJumpTwice) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { const chainName2 = chainName + "2" rules := [][]string{ {"-P", "INPUT", "DROP"}, @@ -593,23 +609,23 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { {"-A", chainName, "-j", chainName2}, {"-A", "INPUT", "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // UDP packets should jump and return twice, eventually hitting the // ACCEPT rule. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpTwice) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpTwice) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDestination verifies that we can filter packets via `-d // <ipaddr>`. -type FilterInputDestination struct{} +type FilterInputDestination struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDestination) Name() string { @@ -617,8 +633,8 @@ func (FilterInputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDestination) ContainerAction(ip net.IP) error { - addrs, err := localAddrs() +func (FilterInputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + addrs, err := localAddrs(ipv6) if err != nil { return err } @@ -629,21 +645,21 @@ func (FilterInputDestination) ContainerAction(ip net.IP) error { for _, addr := range addrs { rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"}) } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDestination) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputInvertDestination verifies that we can filter packets via `! -d // <ipaddr>`. -type FilterInputInvertDestination struct{} +type FilterInputInvertDestination struct{ containerCase } // Name implements TestCase.Name. func (FilterInputInvertDestination) Name() string { @@ -651,28 +667,28 @@ func (FilterInputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertDestination) ContainerAction(ip net.IP) error { +func (FilterInputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ {"-P", "INPUT", "DROP"}, - {"-A", "INPUT", "!", "-d", localIP, "-j", "ACCEPT"}, + {"-A", "INPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertDestination) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputSource verifies that we can filter packets via `-s // <ipaddr>`. -type FilterInputSource struct{} +type FilterInputSource struct{ containerCase } // Name implements TestCase.Name. func (FilterInputSource) Name() string { @@ -680,28 +696,28 @@ func (FilterInputSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSource) ContainerAction(ip net.IP) error { +func (FilterInputSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets from this // machine. rules := [][]string{ {"-P", "INPUT", "DROP"}, {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputSource) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputInvertSource verifies that we can filter packets via `! -s // <ipaddr>`. -type FilterInputInvertSource struct{} +type FilterInputInvertSource struct{ containerCase } // Name implements TestCase.Name. func (FilterInputInvertSource) Name() string { @@ -709,21 +725,21 @@ func (FilterInputInvertSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertSource) ContainerAction(ip net.IP) error { +func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ {"-P", "INPUT", "DROP"}, - {"-A", "INPUT", "!", "-s", localIP, "-j", "ACCEPT"}, + {"-A", "INPUT", "!", "-s", localIP(ipv6), "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertSource) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index ba0d6fc29..32bf2a992 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -15,6 +15,8 @@ package iptables import ( + "context" + "errors" "fmt" "net" ) @@ -44,7 +46,7 @@ func init() { // FilterOutputDropTCPDestPort tests that connections are not accepted on // specified source ports. -type FilterOutputDropTCPDestPort struct{} +type FilterOutputDropTCPDestPort struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPDestPort) Name() string { @@ -52,22 +54,28 @@ func (FilterOutputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { +func (FilterOutputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) } @@ -76,7 +84,7 @@ func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error { // FilterOutputDropTCPSrcPort tests that connections are not accepted on // specified source ports. -type FilterOutputDropTCPSrcPort struct{} +type FilterOutputDropTCPSrcPort struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPSrcPort) Name() string { @@ -84,22 +92,28 @@ func (FilterOutputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterOutputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) } @@ -107,7 +121,7 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { } // FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. -type FilterOutputAcceptTCPOwner struct{} +type FilterOutputAcceptTCPOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputAcceptTCPOwner) Name() string { @@ -115,22 +129,22 @@ func (FilterOutputAcceptTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { +func (FilterOutputAcceptTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputAcceptTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. -type FilterOutputDropTCPOwner struct{} +type FilterOutputDropTCPOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPOwner) Name() string { @@ -138,22 +152,28 @@ func (FilterOutputDropTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { +func (FilterOutputDropTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) } @@ -161,7 +181,7 @@ func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { } // FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. -type FilterOutputAcceptUDPOwner struct{} +type FilterOutputAcceptUDPOwner struct{ localCase } // Name implements TestCase.Name. func (FilterOutputAcceptUDPOwner) Name() string { @@ -169,23 +189,23 @@ func (FilterOutputAcceptUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { +func (FilterOutputAcceptUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Send UDP packets on acceptPort. - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error { +func (FilterOutputAcceptUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. -type FilterOutputDropUDPOwner struct{} +type FilterOutputDropUDPOwner struct{ localCase } // Name implements TestCase.Name. func (FilterOutputDropUDPOwner) Name() string { @@ -193,20 +213,24 @@ func (FilterOutputDropUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { +func (FilterOutputDropUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } // Send UDP packets on dropPort. - return sendUDPLoop(ip, dropPort, sendloopDuration) + return sendUDPLoop(ctx, ip, dropPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { +func (FilterOutputDropUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets should not be received") + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -214,7 +238,7 @@ func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { // FilterOutputOwnerFail tests that without uid/gid option, owner rule // will fail. -type FilterOutputOwnerFail struct{} +type FilterOutputOwnerFail struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputOwnerFail) Name() string { @@ -222,8 +246,8 @@ func (FilterOutputOwnerFail) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { +func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { return fmt.Errorf("Invalid argument") } @@ -231,13 +255,13 @@ func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputOwnerFail) LocalAction(ip net.IP) error { +func (FilterOutputOwnerFail) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // no-op. return nil } // FilterOutputAcceptGIDOwner tests that TCP connections from gid owner are accepted. -type FilterOutputAcceptGIDOwner struct{} +type FilterOutputAcceptGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputAcceptGIDOwner) Name() string { @@ -245,22 +269,22 @@ func (FilterOutputAcceptGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { +func (FilterOutputAcceptGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptGIDOwner) LocalAction(ip net.IP) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputAcceptGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputDropGIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputDropGIDOwner struct{} +type FilterOutputDropGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropGIDOwner) Name() string { @@ -268,22 +292,28 @@ func (FilterOutputDropGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { +func (FilterOutputDropGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropGIDOwner) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -291,7 +321,7 @@ func (FilterOutputDropGIDOwner) LocalAction(ip net.IP) error { } // FilterOutputInvertGIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputInvertGIDOwner struct{} +type FilterOutputInvertGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertGIDOwner) Name() string { @@ -299,26 +329,32 @@ func (FilterOutputInvertGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP) error { +func (FilterOutputInvertGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInvertGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -326,7 +362,7 @@ func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP) error { } // FilterOutputInvertUIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputInvertUIDOwner struct{} +type FilterOutputInvertUIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertUIDOwner) Name() string { @@ -334,27 +370,27 @@ func (FilterOutputInvertUIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP) error { +func (FilterOutputInvertUIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDOwner) LocalAction(ip net.IP) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputInvertUIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputInvertUIDAndGIDOwner tests that TCP connections from uid and gid // owner are dropped. -type FilterOutputInvertUIDAndGIDOwner struct{} +type FilterOutputInvertUIDAndGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertUIDAndGIDOwner) Name() string { @@ -362,26 +398,32 @@ func (FilterOutputInvertUIDAndGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP) error { +func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -390,7 +432,7 @@ func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP) error { // FilterOutputDestination tests that we can selectively allow packets to // certain destinations. -type FilterOutputDestination struct{} +type FilterOutputDestination struct{ localCase } // Name implements TestCase.Name. func (FilterOutputDestination) Name() string { @@ -398,26 +440,26 @@ func (FilterOutputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDestination) ContainerAction(ip net.IP) error { +func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDestination) LocalAction(ip net.IP) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInvertDestination tests that we can selectively allow packets // not headed for a particular destination. -type FilterOutputInvertDestination struct{} +type FilterOutputInvertDestination struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInvertDestination) Name() string { @@ -425,26 +467,26 @@ func (FilterOutputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error { +func (FilterOutputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ - {"-A", "OUTPUT", "!", "-d", localIP, "-j", "ACCEPT"}, + {"-A", "OUTPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertDestination) LocalAction(ip net.IP) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceAccept tests that packets are sent via interface // matching the iptables rule. -type FilterOutputInterfaceAccept struct{} +type FilterOutputInterfaceAccept struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceAccept) Name() string { @@ -452,26 +494,26 @@ func (FilterOutputInterfaceAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error { +func (FilterOutputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") } - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceAccept) LocalAction(ip net.IP) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceDrop tests that packets are not sent via interface // matching the iptables rule. -type FilterOutputInterfaceDrop struct{} +type FilterOutputInterfaceDrop struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceDrop) Name() string { @@ -479,22 +521,26 @@ func (FilterOutputInterfaceDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error { +func (FilterOutputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") } - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error { - if err := listenUDP(acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -502,7 +548,7 @@ func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error { // FilterOutputInterface tests that packets are sent via interface which is // not matching the interface name in the iptables rule. -type FilterOutputInterface struct{} +type FilterOutputInterface struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterface) Name() string { @@ -510,22 +556,22 @@ func (FilterOutputInterface) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterface) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { +func (FilterOutputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterface) LocalAction(ip net.IP) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceBeginsWith tests that packets are not sent via an // interface which begins with the given interface name. -type FilterOutputInterfaceBeginsWith struct{} +type FilterOutputInterfaceBeginsWith struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceBeginsWith) Name() string { @@ -533,18 +579,22 @@ func (FilterOutputInterfaceBeginsWith) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { +func (FilterOutputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error { - if err := listenUDP(acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -552,7 +602,7 @@ func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error { // FilterOutputInterfaceInvertDrop tests that we selectively do not send // packets via interface not matching the interface name. -type FilterOutputInterfaceInvertDrop struct{} +type FilterOutputInterfaceInvertDrop struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInterfaceInvertDrop) Name() string { @@ -560,22 +610,28 @@ func (FilterOutputInterfaceInvertDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { +func (FilterOutputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -584,7 +640,7 @@ func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error { // FilterOutputInterfaceInvertAccept tests that we can selectively send packets // not matching the specific outgoing interface. -type FilterOutputInterfaceInvertAccept struct{} +type FilterOutputInterfaceInvertAccept struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInterfaceInvertAccept) Name() string { @@ -592,16 +648,16 @@ func (FilterOutputInterfaceInvertAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { +func (FilterOutputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go index 16cb4f4da..c2a03f54c 100644 --- a/test/iptables/iptables.go +++ b/test/iptables/iptables.go @@ -16,6 +16,7 @@ package iptables import ( + "context" "fmt" "net" "time" @@ -29,7 +30,11 @@ const IPExchangePort = 2349 const TerminalStatement = "Finished!" // TestTimeout is the timeout used for all tests. -const TestTimeout = 10 * time.Minute +const TestTimeout = 10 * time.Second + +// NegativeTimeout is the time tests should wait to establish the negative +// case, i.e. that connections are not made. +const NegativeTimeout = 2 * time.Second // A TestCase contains one action to run in the container and one to run // locally. The actions run concurrently and each must succeed for the test @@ -40,10 +45,60 @@ type TestCase interface { // ContainerAction runs inside the container. It receives the IP of the // local process. - ContainerAction(ip net.IP) error + ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error // LocalAction runs locally. It receives the IP of the container. - LocalAction(ip net.IP) error + LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error + + // ContainerSufficient indicates whether ContainerAction's return value + // alone indicates whether the test succeeded. + ContainerSufficient() bool + + // LocalSufficient indicates whether LocalAction's return value alone + // indicates whether the test succeeded. + LocalSufficient() bool +} + +// baseCase provides defaults for ContainerSufficient and LocalSufficient when +// both actions are required to finish. +type baseCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (baseCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (baseCase) LocalSufficient() bool { + return false +} + +// localCase provides defaults for ContainerSufficient and LocalSufficient when +// only the local action is required to finish. +type localCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (localCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (localCase) LocalSufficient() bool { + return true +} + +// containerCase provides defaults for ContainerSufficient and LocalSufficient +// when only the container action is required to finish. +type containerCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (containerCase) ContainerSufficient() bool { + return true +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (containerCase) LocalSufficient() bool { + return false } // Tests maps test names to TestCase. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 340f9426e..834f7615f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -15,8 +15,12 @@ package iptables import ( + "context" + "errors" "fmt" "net" + "reflect" + "sync" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -33,12 +37,40 @@ import ( // Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or // to stderr. func singleTest(t *testing.T, test TestCase) { + for _, tc := range []bool{false, true} { + subtest := "IPv4" + if tc { + subtest = "IPv6" + } + t.Run(subtest, func(t *testing.T) { + iptablesTest(t, test, tc) + }) + } +} + +func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + // Wait for the local and container goroutines to finish. + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + d := dockerutil.MakeContainer(ctx, t) + defer func() { + if logs, err := d.Logs(context.Background()); err != nil { + t.Logf("Failed to retrieve container logs.") + } else { + t.Logf("=== Container logs: ===\n%s", logs) + } + // Use a new context, as cleanup should run even when we + // timeout. + d.CleanUp(context.Background()) + }() // Create and start the container. opts := dockerutil.RunOpts{ @@ -46,12 +78,16 @@ func singleTest(t *testing.T, test TestCase) { CapAdd: []string{"NET_ADMIN"}, } d.CopyFiles(&opts, "/runner", "test/iptables/runner/runner") - if err := d.Spawn(opts, "/runner/runner", "-name", test.Name()); err != nil { + args := []string{"/runner/runner", "-name", test.Name()} + if ipv6 { + args = append(args, "-ipv6") + } + if err := d.Spawn(ctx, opts, args...); err != nil { t.Fatalf("docker run failed: %v", err) } // Get the container IP. - ip, err := d.FindIP() + ip, err := d.FindIP(ctx, ipv6) if err != nil { t.Fatalf("failed to get container IP: %v", err) } @@ -62,15 +98,44 @@ func singleTest(t *testing.T, test TestCase) { } // Run our side of the test. - if err := test.LocalAction(ip); err != nil { - t.Fatalf("LocalAction failed: %v", err) - } - - // Wait for the final statement. This structure has the side effect - // that all container logs will appear within the individual test - // context. - if _, err := d.WaitForOutput(TerminalStatement, TestTimeout); err != nil { - t.Fatalf("test failed: %v", err) + errCh := make(chan error, 2) + wg.Add(1) + go func() { + defer wg.Done() + if err := test.LocalAction(ctx, ip, ipv6); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("LocalAction failed: %v", err) + } else { + errCh <- nil + } + if test.LocalSufficient() { + errCh <- nil + } + }() + + // Run the container side. + wg.Add(1) + go func() { + defer wg.Done() + // Wait for the final statement. This structure has the side + // effect that all container logs will appear within the + // individual test context. + if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("ContainerAction failed: %v", err) + } else { + errCh <- nil + } + if test.ContainerSufficient() { + errCh <- nil + } + }() + + for i := 0; i < 2; i++ { + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } } } @@ -83,7 +148,7 @@ func sendIP(ip net.IP) error { // The container may not be listening when we first connect, so retry // upon error. cb := func() error { - c, err := net.DialTCP("tcp4", nil, &contAddr) + c, err := net.DialTCP("tcp", nil, &contAddr) conn = c return err } @@ -244,11 +309,11 @@ func TestInputInvertDestination(t *testing.T) { singleTest(t, FilterInputInvertDestination{}) } -func TestOutputDestination(t *testing.T) { +func TestFilterOutputDestination(t *testing.T) { singleTest(t, FilterOutputDestination{}) } -func TestOutputInvertDestination(t *testing.T) { +func TestFilterOutputInvertDestination(t *testing.T) { singleTest(t, FilterOutputInvertDestination{}) } @@ -260,6 +325,13 @@ func TestNATPreRedirectTCPPort(t *testing.T) { singleTest(t, NATPreRedirectTCPPort{}) } +func TestNATPreRedirectTCPOutgoing(t *testing.T) { + singleTest(t, NATPreRedirectTCPOutgoing{}) +} + +func TestNATOutRedirectTCPIncoming(t *testing.T) { + singleTest(t, NATOutRedirectTCPIncoming{}) +} func TestNATOutRedirectUDPPort(t *testing.T) { singleTest(t, NATOutRedirectUDPPort{}) } @@ -315,3 +387,36 @@ func TestInputSource(t *testing.T) { func TestInputInvertSource(t *testing.T) { singleTest(t, FilterInputInvertSource{}) } + +func TestFilterAddrs(t *testing.T) { + tcs := []struct { + ipv6 bool + addrs []string + want []string + }{ + { + ipv6: false, + addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"}, + want: []string{"192.168.0.1", "192.168.0.2"}, + }, + { + ipv6: true, + addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"}, + want: []string{"::1", "::2"}, + }, + } + + for _, tc := range tcs { + if got := filterAddrs(tc.addrs, tc.ipv6); !reflect.DeepEqual(got, tc.want) { + t.Errorf("%v with IPv6 %t: got %v, but wanted %v", tc.addrs, tc.ipv6, got, tc.want) + } + } +} + +func TestNATPreOriginalDst(t *testing.T) { + singleTest(t, NATPreOriginalDst{}) +} + +func TestNATOutOriginalDst(t *testing.T) { + singleTest(t, NATOutOriginalDst{}) +} diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go new file mode 100644 index 000000000..bd85a8fea --- /dev/null +++ b/test/iptables/iptables_unsafe.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "syscall" + "unsafe" +) + +type originalDstError struct { + errno syscall.Errno +} + +func (e originalDstError) Error() string { + return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error()) +} + +// SO_ORIGINAL_DST gets the original destination of a redirected packet via +// getsockopt. +const SO_ORIGINAL_DST = 80 + +func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) { + var addr syscall.RawSockaddrInet4 + var addrLen uint32 = syscall.SizeofSockaddrInet4 + if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet4{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) { + var addr syscall.RawSockaddrInet6 + var addrLen uint32 = syscall.SizeofSockaddrInet6 + if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet6{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno { + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + uintptr(connfd), + level, + SO_ORIGINAL_DST, + uintptr(optval), + uintptr(unsafe.Pointer(optlen)), + 0) + return errno +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 7146edbb9..a6ec5cca3 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,30 +15,35 @@ package iptables import ( + "context" + "encoding/binary" + "errors" "fmt" "net" "os/exec" + "strings" "time" "gvisor.dev/gvisor/pkg/test/testutil" ) -const iptablesBinary = "iptables" -const localIP = "127.0.0.1" - -// filterTable calls `iptables -t filter` with the given args. -func filterTable(args ...string) error { - return tableCmd("filter", args) +// filterTable calls `ip{6}tables -t filter` with the given args. +func filterTable(ipv6 bool, args ...string) error { + return tableCmd(ipv6, "filter", args) } -// natTable calls `iptables -t nat` with the given args. -func natTable(args ...string) error { - return tableCmd("nat", args) +// natTable calls `ip{6}tables -t nat` with the given args. +func natTable(ipv6 bool, args ...string) error { + return tableCmd(ipv6, "nat", args) } -func tableCmd(table string, args []string) error { +func tableCmd(ipv6 bool, table string, args []string) error { args = append([]string{"-t", table}, args...) - cmd := exec.Command(iptablesBinary, args...) + binary := "iptables" + if ipv6 { + binary = "ip6tables" + } + cmd := exec.Command(binary, args...) if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) } @@ -46,18 +51,18 @@ func tableCmd(table string, args []string) error { } // filterTableRules is like filterTable, but runs multiple iptables commands. -func filterTableRules(argsList [][]string) error { - return tableRules("filter", argsList) +func filterTableRules(ipv6 bool, argsList [][]string) error { + return tableRules(ipv6, "filter", argsList) } // natTableRules is like natTable, but runs multiple iptables commands. -func natTableRules(argsList [][]string) error { - return tableRules("nat", argsList) +func natTableRules(ipv6 bool, argsList [][]string) error { + return tableRules(ipv6, "nat", argsList) } -func tableRules(table string, argsList [][]string) error { +func tableRules(ipv6 bool, table string, argsList [][]string) error { for _, args := range argsList { - if err := tableCmd(table, args); err != nil { + if err := tableCmd(ipv6, table, args); err != nil { return err } } @@ -66,77 +71,91 @@ func tableRules(table string, argsList [][]string) error { // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. -func listenUDP(port int, timeout time.Duration) error { +func listenUDP(ctx context.Context, port int) error { localAddr := net.UDPAddr{ Port: port, } - conn, err := net.ListenUDP(network, &localAddr) + conn, err := net.ListenUDP("udp", &localAddr) if err != nil { return err } defer conn.Close() - conn.SetDeadline(time.Now().Add(timeout)) - _, err = conn.Read([]byte{0}) - return err + + ch := make(chan error) + go func() { + _, err = conn.Read([]byte{0}) + ch <- err + }() + + select { + case err := <-ch: + return err + case <-ctx.Done(): + return ctx.Err() + } } // sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified // over a duration. -func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { - // Send packets for a few seconds. +func sendUDPLoop(ctx context.Context, ip net.IP, port int) error { remote := net.UDPAddr{ IP: ip, Port: port, } - conn, err := net.DialUDP(network, nil, &remote) + conn, err := net.DialUDP("udp", nil, &remote) if err != nil { return err } defer conn.Close() - to := time.After(duration) - for timedOut := false; !timedOut; { + for { // This may return an error (connection refused) if the remote // hasn't started listening yet or they're dropping our // packets. So we ignore Write errors and depend on the remote // to report a failure if it doesn't get a packet it needs. conn.Write([]byte{0}) select { - case <-to: - timedOut = true - default: - time.Sleep(200 * time.Millisecond) + case <-ctx.Done(): + // Being cancelled or timing out isn't an error, as we + // cannot tell with UDP whether we succeeded. + return nil + // Continue looping. + case <-time.After(200 * time.Millisecond): } } - - return nil } // listenTCP listens for connections on a TCP port. -func listenTCP(port int, timeout time.Duration) error { +func listenTCP(ctx context.Context, port int) error { localAddr := net.TCPAddr{ Port: port, } // Starts listening on port. - lConn, err := net.ListenTCP("tcp4", &localAddr) + lConn, err := net.ListenTCP("tcp", &localAddr) if err != nil { return err } defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - conn, err := lConn.AcceptTCP() - if err != nil { + ch := make(chan error) + go func() { + conn, err := lConn.AcceptTCP() + ch <- err + conn.Close() + }() + + select { + case err := <-ch: return err + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) } - conn.Close() - return nil } // connectTCP connects to the given IP and port from an ephemeral local address. -func connectTCP(ip net.IP, port int, timeout time.Duration) error { +func connectTCP(ctx context.Context, ip net.IP, port int) error { contAddr := net.TCPAddr{ IP: ip, Port: port, @@ -144,43 +163,120 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { // The container may not be listening when we first connect, so retry // upon error. callback := func() error { - conn, err := net.DialTimeout("tcp", contAddr.String(), timeout) + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", contAddr.String()) if conn != nil { conn.Close() } return err } - if err := testutil.Poll(callback, timeout); err != nil { + if err := testutil.PollContext(ctx, callback); err != nil { return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) } return nil } -// localAddrs returns a list of local network interface addresses. -func localAddrs() ([]string, error) { +// localAddrs returns a list of local network interface addresses. When ipv6 is +// true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are +// returned. +func localAddrs(ipv6 bool) ([]string, error) { addrs, err := net.InterfaceAddrs() if err != nil { return nil, err } addrStrs := make([]string, 0, len(addrs)) for _, addr := range addrs { - addrStrs = append(addrStrs, addr.String()) + // Add only IPv4 or only IPv6 addresses. + parts := strings.Split(addr.String(), "/") + if len(parts) != 2 { + return nil, fmt.Errorf("bad interface address: %q", addr.String()) + } + if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { + addrStrs = append(addrStrs, addr.String()) + } + } + return filterAddrs(addrStrs, ipv6), nil +} + +func filterAddrs(addrs []string, ipv6 bool) []string { + addrStrs := make([]string, 0, len(addrs)) + for _, addr := range addrs { + // Add only IPv4 or only IPv6 addresses. + parts := strings.Split(addr, "/") + if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { + addrStrs = append(addrStrs, parts[0]) + } } - return addrStrs, nil + return addrStrs } // getInterfaceName returns the name of the interface other than loopback. func getInterfaceName() (string, bool) { - var ifname string + iface, ok := getNonLoopbackInterface() + if !ok { + return "", false + } + return iface.Name, true +} + +func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) { + iface, ok := getNonLoopbackInterface() + if !ok { + return nil, errors.New("no non-loopback interface found") + } + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + // Get only IPv4 or IPv6 addresses. + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + parts := strings.Split(addr.String(), "/") + var ip net.IP + // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses. + // So we check whether To4() returns nil to test whether the + // address is v4 or v6. + if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil { + ip = net.ParseIP(parts[0]).To16() + } else { + ip = v4 + } + if ip != nil { + ips = append(ips, ip) + } + } + return ips, nil +} + +func getNonLoopbackInterface() (net.Interface, bool) { if interfaces, err := net.Interfaces(); err == nil { for _, intf := range interfaces { if intf.Name != "lo" { - ifname = intf.Name - break + return intf, true } } } + return net.Interface{}, false +} + +func htons(x uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, x) + return binary.LittleEndian.Uint16(buf) +} + +func localIP(ipv6 bool) string { + if ipv6 { + return "::1" + } + return "127.0.0.1" +} - return ifname, ifname != "" +func nowhereIP(ipv6 bool) string { + if ipv6 { + return "2001:db8::1" + } + return "192.0.2.1" } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 5e54a3963..dd9a18339 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -15,19 +15,20 @@ package iptables import ( + "context" "errors" "fmt" "net" - "time" + "syscall" ) -const ( - redirectPort = 42 -) +const redirectPort = 42 func init() { RegisterTestCase(NATPreRedirectUDPPort{}) RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectTCPOutgoing{}) + RegisterTestCase(NATOutRedirectTCPIncoming{}) RegisterTestCase(NATOutRedirectUDPPort{}) RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) @@ -40,10 +41,12 @@ func init() { RegisterTestCase(NATOutRedirectInvert{}) RegisterTestCase(NATRedirectRequiresProtocol{}) RegisterTestCase(NATLoopbackSkipsPrerouting{}) + RegisterTestCase(NATPreOriginalDst{}) + RegisterTestCase(NATOutOriginalDst{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. -type NATPreRedirectUDPPort struct{} +type NATPreRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectUDPPort) Name() string { @@ -51,12 +54,12 @@ func (NATPreRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { +func (NATPreRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(redirectPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, redirectPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) } @@ -64,12 +67,12 @@ func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectTCPPort tests that connections are redirected on specified ports. -type NATPreRedirectTCPPort struct{} +type NATPreRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPPort) Name() string { @@ -77,22 +80,72 @@ func (NATPreRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { +func (NATPreRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } // Listen for TCP packets on redirect port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) +} + +// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't +// affected by PREROUTING connection tracking. +type NATPreRedirectTCPOutgoing struct{ baseCase } + +// Name implements TestCase.Name. +func (NATPreRedirectTCPOutgoing) Name() string { + return "NATPreRedirectTCPOutgoing" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPOutgoing) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect all incoming TCP traffic to a closed port. + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return connectTCP(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPOutgoing) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenTCP(ctx, acceptPort) +} + +// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't +// affected by OUTPUT connection tracking. +type NATOutRedirectTCPIncoming struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutRedirectTCPIncoming) Name() string { + return "NATOutRedirectTCPIncoming" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPIncoming) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect all outgoing TCP traffic to a closed port. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATOutRedirectTCPIncoming) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // NATOutRedirectUDPPort tests that packets are redirected to different port. -type NATOutRedirectUDPPort struct{} +type NATOutRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATOutRedirectUDPPort) Name() string { @@ -100,20 +153,19 @@ func (NATOutRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectUDPPort) ContainerAction(ip net.IP) error { - dest := []byte{200, 0, 0, 1} - return loopbackTest(dest, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +func (NATOutRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectUDPPort) LocalAction(ip net.IP) error { +func (NATOutRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATDropUDP tests that packets are not received in ports other than redirect // port. -type NATDropUDP struct{} +type NATDropUDP struct{ containerCase } // Name implements TestCase.Name. func (NATDropUDP) Name() string { @@ -121,25 +173,29 @@ func (NATDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATDropUDP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { +func (NATDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (NATDropUDP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATAcceptAll tests that all UDP packets are accepted. -type NATAcceptAll struct{} +type NATAcceptAll struct{ containerCase } // Name implements TestCase.Name. func (NATAcceptAll) Name() string { @@ -147,12 +203,12 @@ func (NATAcceptAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATAcceptAll) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { +func (NATAcceptAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -160,13 +216,13 @@ func (NATAcceptAll) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATAcceptAll) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATAcceptAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATOutRedirectIP uses iptables to select packets based on destination IP and // redirects them. -type NATOutRedirectIP struct{} +type NATOutRedirectIP struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectIP) Name() string { @@ -174,21 +230,24 @@ func (NATOutRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectIP) ContainerAction(ip net.IP) error { +func (NATOutRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - dest := net.IP([]byte{200, 0, 0, 2}) - return loopbackTest(dest, "-A", "OUTPUT", "-d", dest.String(), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), + "-A", "OUTPUT", + "-d", nowhereIP(ipv6), + "-p", "udp", + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectIP) LocalAction(ip net.IP) error { +func (NATOutRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATOutDontRedirectIP struct{} +type NATOutDontRedirectIP struct{ localCase } // Name implements TestCase.Name. func (NATOutDontRedirectIP) Name() string { @@ -196,20 +255,20 @@ func (NATOutDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutDontRedirectIP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "OUTPUT", "-d", localIP, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { +func (NATOutDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutDontRedirectIP) LocalAction(ip net.IP) error { - return listenUDP(acceptPort, sendloopDuration) +func (NATOutDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // NATOutRedirectInvert tests that iptables can match with "! -d". -type NATOutRedirectInvert struct{} +type NATOutRedirectInvert struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectInvert) Name() string { @@ -217,22 +276,28 @@ func (NATOutRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectInvert) ContainerAction(ip net.IP) error { +func (NATOutRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - dest := []byte{200, 0, 0, 3} - destStr := "200.0.0.2" - return loopbackTest(dest, "-A", "OUTPUT", "!", "-d", destStr, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) + dest := "192.0.2.2" + if ipv6 { + dest = "2001:db8::2" + } + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), + "-A", "OUTPUT", + "!", "-d", dest, + "-p", "udp", + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectInvert) LocalAction(ip net.IP) error { +func (NATOutRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreRedirectIP tests that we can use iptables to select packets based on // destination IP and redirect them. -type NATPreRedirectIP struct{} +type NATPreRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectIP) Name() string { @@ -240,8 +305,8 @@ func (NATPreRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectIP) ContainerAction(ip net.IP) error { - addrs, err := localAddrs() +func (NATPreRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + addrs, err := localAddrs(ipv6) if err != nil { return err } @@ -250,20 +315,20 @@ func (NATPreRedirectIP) ContainerAction(ip net.IP) error { for _, addr := range addrs { rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)}) } - if err := natTableRules(rules); err != nil { + if err := natTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectIP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATPreDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATPreDontRedirectIP struct{} +type NATPreDontRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreDontRedirectIP) Name() string { @@ -271,20 +336,20 @@ func (NATPreDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { +func (NATPreDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreDontRedirectIP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectInvert tests that iptables can match with "! -d". -type NATPreRedirectInvert struct{} +type NATPreRedirectInvert struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectInvert) Name() string { @@ -292,21 +357,21 @@ func (NATPreRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectInvert) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "!", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { +func (NATPreRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectInvert) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a // protocol to be specified with -p. -type NATRedirectRequiresProtocol struct{} +type NATRedirectRequiresProtocol struct{ baseCase } // Name implements TestCase.Name. func (NATRedirectRequiresProtocol) Name() string { @@ -314,21 +379,21 @@ func (NATRedirectRequiresProtocol) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { +func (NATRedirectRequiresProtocol) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { return errors.New("expected an error using REDIRECT --to-ports without a protocol") } return nil } // LocalAction implements TestCase.LocalAction. -func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error { +func (NATRedirectRequiresProtocol) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutRedirectTCPPort tests that connections are redirected on specified ports. -type NATOutRedirectTCPPort struct{} +type NATOutRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPPort) Name() string { @@ -336,15 +401,13 @@ func (NATOutRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { +func (NATOutRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - timeout := 20 * time.Second - dest := []byte{127, 0, 0, 1} localAddr := net.TCPAddr{ - IP: dest, + IP: net.ParseIP(localIP(ipv6)), Port: acceptPort, } @@ -356,9 +419,7 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - err = connectTCP(ip, dropPort, timeout) - if err != nil { + if err := connectTCP(ctx, ip, dropPort); err != nil { return err } @@ -372,13 +433,13 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error { +func (NATOutRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { return nil } // NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't // affected by PREROUTING rules. -type NATLoopbackSkipsPrerouting struct{} +type NATLoopbackSkipsPrerouting struct{ baseCase } // Name implements TestCase.Name. func (NATLoopbackSkipsPrerouting) Name() string { @@ -386,10 +447,10 @@ func (NATLoopbackSkipsPrerouting) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error { +func (NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. dest := []byte{127, 0, 0, 1} - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } @@ -397,43 +458,200 @@ func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error { // loopback traffic, the connection would fail. sendCh := make(chan error) go func() { - sendCh <- connectTCP(dest, acceptPort, sendloopDuration) + sendCh <- connectTCP(ctx, dest, acceptPort) }() - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return err } return <-sendCh } // LocalAction implements TestCase.LocalAction. -func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP) error { +func (NATLoopbackSkipsPrerouting) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } -// loopbackTests runs an iptables rule and ensures that packets sent to -// dest:dropPort are received by localhost:acceptPort. -func loopbackTest(dest net.IP, args ...string) error { - if err := natTable(args...); err != nil { +// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of PREROUTING NATted packets. +type NATPreOriginalDst struct{ baseCase } + +// Name implements TestCase.Name. +func (NATPreOriginalDst) Name() string { + return "NATPreOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "PREROUTING", + "-p", "tcp", + "--destination-port", fmt.Sprintf("%d", dropPort), + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - sendCh := make(chan error) - listenCh := make(chan error) + + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + return listenForRedirectedConn(ctx, ipv6, addrs) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) +} + +// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of OUTBOUND NATted packets. +type NATOutOriginalDst struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutOriginalDst) Name() string { + return "NATOutOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + connCh := make(chan error) go func() { - sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) + connCh <- connectTCP(ctx, ip, dropPort) }() + + if err := listenForRedirectedConn(ctx, ipv6, []net.IP{ip}); err != nil { + return err + } + return <-connCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { + // The net package doesn't give guarantee access to the connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for SO_ORIGINAL_DST. + + // Create the listening socket, bind, listen, and accept. + family := syscall.AF_INET + if ipv6 { + family = syscall.AF_INET6 + } + sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0) + if err != nil { + return err + } + defer syscall.Close(sockfd) + + var bindAddr syscall.Sockaddr + if ipv6 { + bindAddr = &syscall.SockaddrInet6{ + Port: acceptPort, + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } else { + bindAddr = &syscall.SockaddrInet4{ + Port: acceptPort, + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + } + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return err + } + + if err := syscall.Listen(sockfd, 1); err != nil { + return err + } + + // Block on accept() in another goroutine. + connCh := make(chan int) + errCh := make(chan error) go func() { - listenCh <- listenUDP(acceptPort, sendloopDuration) + connFD, _, err := syscall.Accept(sockfd) + if err != nil { + errCh <- err + } + connCh <- connFD }() + + // Wait for accept() to return or for the context to finish. + var connFD int select { - case err := <-listenCh: + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case connFD = <-connCh: + } + defer syscall.Close(connFD) + + // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST + // indicates the packet was sent to originalDst:dropPort. + if ipv6 { + got, err := originalDestination6(connFD) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } else { + got, err := originalDestination4(connFD) if err != nil { return err } - case <-time.After(sendloopDuration): - return errors.New("timed out") + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } +} + +// loopbackTests runs an iptables rule and ensures that packets sent to +// dest:dropPort are received by localhost:acceptPort. +func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) error { + if err := natTable(ipv6, args...); err != nil { + return err + } + sendCh := make(chan error, 1) + listenCh := make(chan error, 1) + go func() { + sendCh <- sendUDPLoop(ctx, dest, dropPort) + }() + go func() { + listenCh <- listenUDP(ctx, acceptPort) + }() + select { + case err := <-listenCh: + return err + case err := <-sendCh: + return err } - // sendCh will always take the full sendloop time. - return <-sendCh } diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go index 6f77c0684..9ae2d1b4d 100644 --- a/test/iptables/runner/main.go +++ b/test/iptables/runner/main.go @@ -16,6 +16,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -24,7 +25,10 @@ import ( "gvisor.dev/gvisor/test/iptables" ) -var name = flag.String("name", "", "name of the test to run") +var ( + name = flag.String("name", "", "name of the test to run") + ipv6 = flag.Bool("ipv6", false, "whether the test utilizes ip6tables") +) func main() { flag.Parse() @@ -43,7 +47,9 @@ func main() { } // Run the test. - if err := test.ContainerAction(ip); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := test.ContainerAction(ctx, ip, *ipv6); err != nil { log.Fatalf("Failed running test %q: %v", *name, err) } @@ -57,7 +63,7 @@ func getIP() (net.IP, error) { localAddr := net.TCPAddr{ Port: iptables.IPExchangePort, } - listener, err := net.ListenTCP("tcp4", &localAddr) + listener, err := net.ListenTCP("tcp", &localAddr) if err != nil { return net.IP{}, fmt.Errorf("failed listening for IP: %v", err) } diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index dfcd55f60..49642f282 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -1,4 +1,5 @@ -load("defs.bzl", "packetdrill_test") +load("//tools:defs.bzl", "bzl_library") +load("//test/packetdrill:defs.bzl", "packetdrill_test") package(licenses = ["notice"]) @@ -36,3 +37,9 @@ packetdrill_test( name = "tcp_defer_accept_timeout_test", scripts = ["tcp_defer_accept_timeout.pkt"], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl index f499c177b..fc28ce9ba 100644 --- a/test/packetdrill/defs.bzl +++ b/test/packetdrill/defs.bzl @@ -26,7 +26,7 @@ def _packetdrill_test_impl(ctx): transitive_files = depset() if hasattr(ctx.attr._test_runner, "data_runfiles"): - transitive_files = depset(ctx.attr._test_runner.data_runfiles.files) + transitive_files = ctx.attr._test_runner.data_runfiles.files runfiles = ctx.runfiles( files = [test_runner] + ctx.files._init_script + ctx.files.scripts, transitive_files = transitive_files, @@ -60,11 +60,15 @@ _packetdrill_test = rule( implementation = _packetdrill_test_impl, ) -_PACKETDRILL_TAGS = ["local", "manual"] +PACKETDRILL_TAGS = [ + "local", + "manual", + "packetdrill", +] def packetdrill_linux_test(name, **kwargs): if "tags" not in kwargs: - kwargs["tags"] = _PACKETDRILL_TAGS + kwargs["tags"] = PACKETDRILL_TAGS _packetdrill_test( name = name, flags = ["--dut_platform", "linux"], @@ -73,7 +77,7 @@ def packetdrill_linux_test(name, **kwargs): def packetdrill_netstack_test(name, **kwargs): if "tags" not in kwargs: - kwargs["tags"] = _PACKETDRILL_TAGS + kwargs["tags"] = PACKETDRILL_TAGS _packetdrill_test( name = name, # This is the default runtime unless diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md index f46c67a0c..ffa96ba98 100644 --- a/test/packetimpact/README.md +++ b/test/packetimpact/README.md @@ -30,7 +30,7 @@ $ make load-packetimpact Run a test, e.g. `fin_wait2_timeout`, against Linux: ```bash -$ bazel test //test/packetimpact/tests:fin_wait2_timeout_linux_test +$ bazel test //test/packetimpact/tests:fin_wait2_timeout_native_test ``` Run the same test, but against gVisor: diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD index 3ce63c2c6..ccf1c735f 100644 --- a/test/packetimpact/dut/BUILD +++ b/test/packetimpact/dut/BUILD @@ -16,3 +16,13 @@ cc_binary( "//test/packetimpact/proto:posix_server_cc_proto", ], ) + +cc_binary( + name = "posix_server_dynamic", + srcs = ["posix_server.cc"], + deps = [ + grpcpp, + "//test/packetimpact/proto:posix_server_cc_grpc_proto", + "//test/packetimpact/proto:posix_server_cc_proto", + ], +) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index a1a5c3612..4de8540f6 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -21,6 +21,7 @@ #include <string.h> #include <sys/socket.h> #include <sys/types.h> +#include <time.h> #include <unistd.h> #include <iostream> @@ -28,6 +29,7 @@ #include "include/grpcpp/security/server_credentials.h" #include "include/grpcpp/server_builder.h" +#include "include/grpcpp/server_context.h" #include "test/packetimpact/proto/posix_server.grpc.pb.h" #include "test/packetimpact/proto/posix_server.pb.h" @@ -53,7 +55,10 @@ response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo)); response_in6->mutable_addr()->assign( reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16); - response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id)); + // sin6_scope_id is stored in host byte order. + // + // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html + response_in6->set_scope_id(addr_in6->sin6_scope_id); return ::grpc::Status::OK; } } @@ -89,7 +94,10 @@ addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo()); proto_in6.addr().copy( reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16); - addr_in6->sin6_scope_id = htonl(proto_in6.scope_id()); + // sin6_scope_id is stored in host byte order. + // + // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html + addr_in6->sin6_scope_id = proto_in6.scope_id(); *addr_len = sizeof(*addr_in6); break; } @@ -102,18 +110,20 @@ } class PosixImpl final : public posix_server::Posix::Service { - ::grpc::Status Accept(grpc_impl::ServerContext *context, + ::grpc::Status Accept(grpc::ServerContext *context, const ::posix_server::AcceptRequest *request, ::posix_server::AcceptResponse *response) override { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); response->set_fd(accept(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); - response->set_errno_(errno); + if (response->fd() < 0) { + response->set_errno_(errno); + } return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } - ::grpc::Status Bind(grpc_impl::ServerContext *context, + ::grpc::Status Bind(grpc::ServerContext *context, const ::posix_server::BindRequest *request, ::posix_server::BindResponse *response) override { if (!request->has_addr()) { @@ -130,19 +140,23 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret( bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Close(grpc_impl::ServerContext *context, + ::grpc::Status Close(grpc::ServerContext *context, const ::posix_server::CloseRequest *request, ::posix_server::CloseResponse *response) override { response->set_ret(close(request->fd())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Connect(grpc_impl::ServerContext *context, + ::grpc::Status Connect(grpc::ServerContext *context, const ::posix_server::ConnectRequest *request, ::posix_server::ConnectResponse *response) override { if (!request->has_addr()) { @@ -158,32 +172,38 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret(connect(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Fcntl(grpc_impl::ServerContext *context, + ::grpc::Status Fcntl(grpc::ServerContext *context, const ::posix_server::FcntlRequest *request, ::posix_server::FcntlResponse *response) override { response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } ::grpc::Status GetSockName( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::GetSockNameRequest *request, ::posix_server::GetSockNameResponse *response) override { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); response->set_ret(getsockname( request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } ::grpc::Status GetSockOpt( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::GetSockOptRequest *request, ::posix_server::GetSockOptResponse *response) override { switch (request->type()) { @@ -220,15 +240,19 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown SockOpt Type"); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Listen(grpc_impl::ServerContext *context, + ::grpc::Status Listen(grpc::ServerContext *context, const ::posix_server::ListenRequest *request, ::posix_server::ListenResponse *response) override { response->set_ret(listen(request->sockfd(), request->backlog())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -237,7 +261,9 @@ class PosixImpl final : public posix_server::Posix::Service { ::posix_server::SendResponse *response) override { response->set_ret(::send(request->sockfd(), request->buf().data(), request->buf().size(), request->flags())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -258,12 +284,14 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret(::sendto(request->sockfd(), request->buf().data(), request->buf().size(), request->flags(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } ::grpc::Status SetSockOpt( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::SetSockOptRequest *request, ::posix_server::SetSockOptResponse *response) override { switch (request->optval().val_case()) { @@ -280,9 +308,9 @@ class PosixImpl final : public posix_server::Posix::Service { break; } case ::posix_server::SockOptVal::kTimeval: { - timeval tv = {.tv_sec = static_cast<__time_t>( + timeval tv = {.tv_sec = static_cast<time_t>( request->optval().timeval().seconds()), - .tv_usec = static_cast<__suseconds_t>( + .tv_usec = static_cast<suseconds_t>( request->optval().timeval().microseconds())}; response->set_ret(setsockopt(request->sockfd(), request->level(), request->optname(), &tv, sizeof(tv))); @@ -292,16 +320,29 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown SockOpt Type"); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Socket(grpc_impl::ServerContext *context, + ::grpc::Status Socket(grpc::ServerContext *context, const ::posix_server::SocketRequest *request, ::posix_server::SocketResponse *response) override { response->set_fd( socket(request->domain(), request->type(), request->protocol())); - response->set_errno_(errno); + if (response->fd() < 0) { + response->set_errno_(errno); + } + return ::grpc::Status::OK; + } + + ::grpc::Status Shutdown(grpc::ServerContext *context, + const ::posix_server::ShutdownRequest *request, + ::posix_server::ShutdownResponse *response) override { + if (shutdown(request->fd(), request->how()) < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -314,7 +355,9 @@ class PosixImpl final : public posix_server::Posix::Service { if (response->ret() >= 0) { response->set_buf(buf.data(), response->ret()); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } }; diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD index 422bb9b0c..8d1193fed 100644 --- a/test/packetimpact/netdevs/BUILD +++ b/test/packetimpact/netdevs/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package( licenses = ["notice"], @@ -13,3 +13,11 @@ go_library( "//pkg/tcpip/header", ], ) + +go_test( + name = "netdevs_test", + size = "small", + srcs = ["netdevs_test.go"], + library = ":netdevs", + deps = ["@com_github_google_go_cmp//cmp:go_default_library"], +) diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go index d2c9cfeaf..eecfe0730 100644 --- a/test/packetimpact/netdevs/netdevs.go +++ b/test/packetimpact/netdevs/netdevs.go @@ -19,6 +19,7 @@ import ( "fmt" "net" "regexp" + "strconv" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -27,6 +28,7 @@ import ( // A DeviceInfo represents a network device. type DeviceInfo struct { + ID uint32 MAC net.HardwareAddr IPv4Addr net.IP IPv4Net *net.IPNet @@ -35,7 +37,7 @@ type DeviceInfo struct { } var ( - deviceLine = regexp.MustCompile(`^\s*\d+: (\w+)`) + deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`) linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`) inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`) inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`) @@ -43,6 +45,11 @@ var ( // ParseDevices parses the output from `ip addr show` into a map from device // name to information about the device. +// +// Note: if multiple IPv6 addresses are assigned to a device, the last address +// displayed by `ip addr show` will be used. This is fine for packetimpact +// because we will always only have at most one IPv6 address assigned to each +// device. func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) { var currentDevice string var currentInfo DeviceInfo @@ -52,8 +59,12 @@ func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) { if currentDevice != "" { deviceInfos[currentDevice] = currentInfo } - currentInfo = DeviceInfo{} - currentDevice = m[1] + id, err := strconv.ParseUint(m[1], 10, 32) + if err != nil { + return nil, fmt.Errorf("parsing device ID %s: %w", m[1], err) + } + currentInfo = DeviceInfo{ID: uint32(id)} + currentDevice = m[2] } else if m := linkLine.FindStringSubmatch(line); m != nil { mac, err := net.ParseMAC(m[1]) if err != nil { diff --git a/test/packetimpact/netdevs/netdevs_test.go b/test/packetimpact/netdevs/netdevs_test.go new file mode 100644 index 000000000..24ad12198 --- /dev/null +++ b/test/packetimpact/netdevs/netdevs_test.go @@ -0,0 +1,227 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package netdevs + +import ( + "fmt" + "net" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func mustParseMAC(s string) net.HardwareAddr { + mac, err := net.ParseMAC(s) + if err != nil { + panic(fmt.Sprintf("failed to parse test MAC %q: %s", s, err)) + } + return mac +} + +func TestParseDevices(t *testing.T) { + for _, v := range []struct { + desc string + cmdOutput string + want map[string]DeviceInfo + }{ + { + desc: "v4 and v6", + cmdOutput: ` +1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000 + link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00 + inet 127.0.0.1/8 scope host lo + valid_lft forever preferred_lft forever + inet6 ::1/128 scope host + valid_lft forever preferred_lft forever +2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0 + valid_lft forever preferred_lft forever + inet6 fe80::42:c0ff:fea8:902/64 scope link tentative + valid_lft forever preferred_lft forever +2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 223.245.225.10/24 brd 223.245.225.255 scope global eth2 + valid_lft forever preferred_lft forever + inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative + valid_lft forever preferred_lft forever +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "lo": DeviceInfo{ + ID: 1, + MAC: mustParseMAC("00:00:00:00:00:00"), + IPv4Addr: net.IPv4(127, 0, 0, 1), + IPv4Net: &net.IPNet{ + IP: net.IPv4(127, 0, 0, 0), + Mask: net.CIDRMask(8, 32), + }, + IPv6Addr: net.ParseIP("::1"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("::1"), + Mask: net.CIDRMask(128, 128), + }, + }, + "eth0": DeviceInfo{ + ID: 2613, + MAC: mustParseMAC("02:42:c0:a8:09:02"), + IPv4Addr: net.IPv4(192, 168, 9, 2), + IPv4Net: &net.IPNet{ + IP: net.IPv4(192, 168, 9, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:c0ff:fea8:902"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + "eth1": DeviceInfo{ + ID: 2617, + MAC: mustParseMAC("02:42:da:33:13:0a"), + IPv4Addr: net.IPv4(218, 51, 19, 10), + IPv4Net: &net.IPNet{ + IP: net.IPv4(218, 51, 19, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:daff:fe33:130a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + "eth2": DeviceInfo{ + ID: 2615, + MAC: mustParseMAC("02:42:df:f5:e1:0a"), + IPv4Addr: net.IPv4(223, 245, 225, 10), + IPv4Net: &net.IPNet{ + IP: net.IPv4(223, 245, 225, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + }, + }, + { + desc: "v4 only", + cmdOutput: ` +2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0 + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "eth0": DeviceInfo{ + ID: 2613, + MAC: mustParseMAC("02:42:c0:a8:09:02"), + IPv4Addr: net.IPv4(192, 168, 9, 2), + IPv4Net: &net.IPNet{ + IP: net.IPv4(192, 168, 9, 0), + Mask: net.CIDRMask(24, 32), + }, + }, + }, + }, + { + desc: "v6 only", + cmdOutput: ` +2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "eth2": DeviceInfo{ + ID: 2615, + MAC: mustParseMAC("02:42:df:f5:e1:0a"), + IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + }, + }, + } { + t.Run(v.desc, func(t *testing.T) { + got, err := ParseDevices(v.cmdOutput) + if err != nil { + t.Errorf("ParseDevices(\n%s\n) got unexpected error: %s", v.cmdOutput, err) + } + if diff := cmp.Diff(v.want, got); diff != "" { + t.Errorf("ParseDevices(\n%s\n) got output diff (-want, +got):\n%s", v.cmdOutput, diff) + } + }) + } +} + +func TestParseDevicesErrors(t *testing.T) { + for _, v := range []struct { + desc string + cmdOutput string + }{ + { + desc: "invalid MAC addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a:ffffffff brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid v4 addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 1234.4321.424242.0/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid v6 addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80:ffffffff::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid CIDR missing prefixlen", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a scope link tentative + valid_lft forever preferred_lft forever`, + }, + } { + t.Run(v.desc, func(t *testing.T) { + if _, err := ParseDevices(v.cmdOutput); err == nil { + t.Errorf("ParseDevices(\n%s\n) succeeded unexpectedly, want error", v.cmdOutput) + } + }) + } +} diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index ccd20b10d..f32ed54ef 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -188,6 +188,15 @@ message SocketResponse { int32 errno_ = 2; // "errno" may fail to compile in c++. } +message ShutdownRequest { + int32 fd = 1; + int32 how = 2; +} + +message ShutdownResponse { + int32 errno_ = 1; // "errno" may fail to compile in c++. +} + message RecvRequest { int32 sockfd = 1; int32 len = 2; @@ -225,6 +234,8 @@ service Posix { rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); // Call socket() on the DUT. rpc Socket(SocketRequest) returns (SocketResponse); + // Call shutdown() on the DUT. + rpc Shutdown(ShutdownRequest) returns (ShutdownResponse); // Call recv() on the DUT. rpc Recv(RecvRequest) returns (RecvResponse); } diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD index 0b68a760a..605dd4972 100644 --- a/test/packetimpact/runner/BUILD +++ b/test/packetimpact/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_test") +load("//tools:defs.bzl", "bzl_library", "go_library", "go_test") package( default_visibility = ["//test/packetimpact:__subpackages__"], @@ -7,14 +7,31 @@ package( go_test( name = "packetimpact_test", - srcs = ["packetimpact_test.go"], + srcs = [ + "packetimpact_test.go", + ], tags = [ # Not intended to be run directly. "local", "manual", ], + deps = [":runner"], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//test/packetimpact:__subpackages__"], +) + +go_library( + name = "runner", + testonly = True, + srcs = ["dut.go"], + visibility = ["//test/packetimpact:__subpackages__"], deps = [ "//pkg/test/dockerutil", "//test/packetimpact/netdevs", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 77cdfea12..f56d3c42e 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -23,8 +23,9 @@ def _packetimpact_test_impl(ctx): transitive_files = [] if hasattr(ctx.attr._test_runner, "data_runfiles"): transitive_files.append(ctx.attr._test_runner.data_runfiles.files) + files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server runfiles = ctx.runfiles( - files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, + files = files, transitive_files = depset(transitive = transitive_files), collect_default = True, collect_data = True, @@ -38,7 +39,7 @@ _packetimpact_test = rule( cfg = "target", default = ":packetimpact_test", ), - "_posix_server_binary": attr.label( + "_posix_server": attr.label( cfg = "target", default = "//test/packetimpact/dut:posix_server", ), @@ -55,14 +56,18 @@ _packetimpact_test = rule( implementation = _packetimpact_test_impl, ) -PACKETIMPACT_TAGS = ["local", "manual"] +PACKETIMPACT_TAGS = [ + "local", + "manual", + "packetimpact", +] -def packetimpact_linux_test( +def packetimpact_native_test( name, testbench_binary, expect_failure = False, **kwargs): - """Add a packetimpact test on linux. + """Add a native packetimpact test. Args: name: name of the test @@ -72,10 +77,10 @@ def packetimpact_linux_test( """ expect_failure_flag = ["--expect_failure"] if expect_failure else [] _packetimpact_test( - name = name + "_linux_test", + name = name + "_native_test", testbench_binary = testbench_binary, - flags = ["--dut_platform", "linux"] + expect_failure_flag, - tags = PACKETIMPACT_TAGS + ["packetimpact"], + flags = ["--native"] + expect_failure_flag, + tags = PACKETIMPACT_TAGS, **kwargs ) @@ -98,21 +103,21 @@ def packetimpact_netstack_test( _packetimpact_test( name = name + "_netstack_test", testbench_binary = testbench_binary, - # This is the default runtime unless - # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. - flags = ["--dut_platform", "netstack", "--runtime=runsc-d"] + expect_failure_flag, - tags = PACKETIMPACT_TAGS + ["packetimpact"], + # Note that a distinct runtime must be provided in the form + # --test_arg=--runtime=other when invoking bazel. + flags = expect_failure_flag, + tags = PACKETIMPACT_TAGS, **kwargs ) -def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure = False, expect_netstack_failure = False, **kwargs): +def packetimpact_go_test(name, size = "small", pure = True, expect_native_failure = False, expect_netstack_failure = False, **kwargs): """Add packetimpact tests written in go. Args: name: name of the test size: size of the test pure: make a static go binary - expect_linux_failure: the test must fail for Linux + expect_native_failure: the test must fail natively expect_netstack_failure: the test must fail for Netstack **kwargs: all the other args, forwarded to go_test """ @@ -121,12 +126,16 @@ def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure name = testbench_binary, size = size, pure = pure, - tags = PACKETIMPACT_TAGS, + nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + tags = [ + "local", + "manual", + ], **kwargs ) - packetimpact_linux_test( + packetimpact_native_test( name = name, - expect_failure = expect_linux_failure, + expect_failure = expect_native_failure, testbench_binary = testbench_binary, ) packetimpact_netstack_test( diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go new file mode 100644 index 000000000..59bb68eb1 --- /dev/null +++ b/test/packetimpact/runner/dut.go @@ -0,0 +1,442 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package runner starts docker containers and networking for a packetimpact test. +package runner + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +// stringList implements flag.Value. +type stringList []string + +// String implements flag.Value.String. +func (l *stringList) String() string { + return strings.Join(*l, ",") +} + +// Set implements flag.Value.Set. +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +var ( + native = false + testbenchBinary = "" + tshark = false + extraTestArgs = stringList{} + expectFailure = false + + // DutAddr is the IP addres for DUT. + DutAddr = net.IPv4(0, 0, 0, 10) + testbenchAddr = net.IPv4(0, 0, 0, 20) +) + +// RegisterFlags defines flags and associates them with the package-level +// exported variables above. It should be called by tests in their init +// functions. +func RegisterFlags(fs *flag.FlagSet) { + fs.BoolVar(&native, "native", false, "whether the test should be run natively") + fs.StringVar(&testbenchBinary, "testbench_binary", "", "path to the testbench binary") + fs.BoolVar(&tshark, "tshark", false, "use more verbose tshark in logs instead of tcpdump") + fs.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") + fs.BoolVar(&expectFailure, "expect_failure", false, "expect that the test will fail when run") +} + +// CtrlPort is the port that posix_server listens on. +const CtrlPort = "40000" + +// logger implements testutil.Logger. +// +// Labels logs based on their source and formats multi-line logs. +type logger string + +// Name implements testutil.Logger.Name. +func (l logger) Name() string { + return string(l) +} + +// Logf implements testutil.Logger.Logf. +func (l logger) Logf(format string, args ...interface{}) { + lines := strings.Split(fmt.Sprintf(format, args...), "\n") + log.Printf("%s: %s", l, lines[0]) + for _, line := range lines[1:] { + log.Printf("%*s %s", len(l), "", line) + } +} + +// TestWithDUT runs a packetimpact test with the given information. +func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT, containerAddr net.IP) { + if testbenchBinary == "" { + t.Fatal("--testbench_binary is missing") + } + dockerutil.EnsureSupportedDockerVersion() + + // Create the networks needed for the test. One control network is needed for + // the gRPC control packets and one test network on which to transmit the test + // packets. + ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) + testNet := dockerutil.NewNetwork(ctx, logger("testNet")) + for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { + for { + if err := createDockerNetwork(ctx, dn); err != nil { + t.Log("creating docker network:", err) + const wait = 100 * time.Millisecond + t.Logf("sleeping %s and will try creating docker network again", wait) + // This can fail if another docker network claimed the same IP so we'll + // just try again. + time.Sleep(wait) + continue + } + break + } + dn := dn + t.Cleanup(func() { + if err := dn.Cleanup(ctx); err != nil { + t.Errorf("unable to cleanup container %s: %s", dn.Name, err) + } + }) + // Sanity check. + if inspect, err := dn.Inspect(ctx); err != nil { + t.Fatalf("failed to inspect network %s: %v", dn.Name, err) + } else if inspect.Name != dn.Name { + t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) + } + } + + tmpDir, err := ioutil.TempDir("", "container-output") + if err != nil { + t.Fatal("creating temp dir:", err) + } + t.Cleanup(func() { + if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { + t.Errorf("unable to copy container output files: %s", err) + } + if err := os.RemoveAll(tmpDir); err != nil { + t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err) + } + }) + + const testOutputDir = "/tmp/testoutput" + + // Create the Docker container for the DUT. + var dut *dockerutil.Container + if native { + dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) + } else { + dut = dockerutil.MakeContainer(ctx, logger("dut")) + } + t.Cleanup(func() { + dut.CleanUp(ctx) + }) + + runOpts := dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + + device := mkDevice(dut) + remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr) + + // Create the Docker container for the testbench. + testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) + + tbb := path.Base(testbenchBinary) + containerTestbenchBinary := filepath.Join("/packetimpact", tbb) + testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb)) + + // snifferNetDev is a network device on the test orchestrator that we will + // run sniffer (tcpdump or tshark) on and inject traffic to, not to be + // confused with the device on the DUT. + const snifferNetDev = "eth2" + // Run tcpdump in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs := []string{ + "tcpdump", + "-S", "-vvv", "-U", "-n", + "-i", snifferNetDev, + "-w", testOutputDir + "/dump.pcap", + } + snifferRegex := "tcpdump: listening.*\n" + if tshark { + // Run tshark in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs = []string{ + "tshark", "-V", "-l", "-n", "-i", snifferNetDev, + "-o", "tcp.check_checksum:TRUE", + "-o", "udp.check_checksum:TRUE", + } + snifferRegex = "Capturing on.*\n" + } + + if err := StartContainer( + ctx, + runOpts, + testbench, + testbenchAddr, + []*dockerutil.Network{ctrlNet, testNet}, + snifferArgs..., + ); err != nil { + t.Fatalf("failed to start docker container for testbench sniffer: %s", err) + } + // Kill so that it will flush output. + t.Cleanup(func() { + time.Sleep(1 * time.Second) + testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) + }) + + if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { + t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) + } + + // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it + // will respond with an RST. In most packetimpact tests, the SYN is sent + // by the raw socket and the kernel knows nothing about the connection, this + // behavior will break lots of TCP related packetimpact tests. To prevent + // this, we can install the following iptables rules. The raw socket that + // packetimpact tests use will still be able to see everything. + for _, bin := range []string{"iptables", "ip6tables"} { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) + } + } + + // FIXME(b/156449515): Some piece of the system has a race. The old + // bash script version had a sleep, so we have one too. The race should + // be fixed and this sleep removed. + time.Sleep(time.Second) + + // Start a packetimpact test on the test bench. The packetimpact test sends + // and receives packets and also sends POSIX socket commands to the + // posix_server to be executed on the DUT. + testArgs := []string{containerTestbenchBinary} + testArgs = append(testArgs, extraTestArgs...) + testArgs = append(testArgs, + "--posix_server_ip", AddressInSubnet(DutAddr, *ctrlNet.Subnet).String(), + "--posix_server_port", CtrlPort, + "--remote_ipv4", AddressInSubnet(DutAddr, *testNet.Subnet).String(), + "--local_ipv4", AddressInSubnet(testbenchAddr, *testNet.Subnet).String(), + "--remote_ipv6", remoteIPv6.String(), + "--remote_mac", remoteMAC.String(), + "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID), + "--local_device", snifferNetDev, + "--remote_device", dutTestNetDev, + fmt.Sprintf("--native=%t", native), + ) + testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) + if (err != nil) != expectFailure { + var dutLogs string + if logs, err := device.Logs(ctx); err != nil { + dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) + } else { + dutLogs = logs + } + + t.Errorf(`test error: %v, expect failure: %t + +%s + +====== Begin of Testbench Logs ====== + +%s + +====== End of Testbench Logs ======`, + err, expectFailure, dutLogs, testbenchLogs) + } +} + +// DUT describes how to setup/teardown the dut for packetimpact tests. +type DUT interface { + // Prepare prepares the dut, starts posix_server and returns the IPv6, MAC + // address, the interface ID, and the interface name for the testNet on DUT. + Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) + // Logs retrieves the logs from the dut. + Logs(ctx context.Context) (string, error) +} + +// DockerDUT describes a docker based DUT. +type DockerDUT struct { + c *dockerutil.Container +} + +// NewDockerDUT creates a docker based DUT. +func NewDockerDUT(c *dockerutil.Container) DUT { + return &DockerDUT{ + c: c, + } +} + +// Prepare implements DUT.Prepare. +func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) { + const containerPosixServerBinary = "/packetimpact/posix_server" + dut.c.CopyFiles(&runOpts, "/packetimpact", "test/packetimpact/dut/posix_server") + + if err := StartContainer( + ctx, + runOpts, + dut.c, + containerAddr, + []*dockerutil.Network{ctrlNet, testNet}, + containerPosixServerBinary, + "--ip=0.0.0.0", + "--port="+CtrlPort, + ); err != nil { + t.Fatalf("failed to start docker container for DUT: %s", err) + } + + if _, err := dut.c.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { + t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err) + } + + dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + + remoteMAC := dutDeviceInfo.MAC + remoteIPv6 := dutDeviceInfo.IPv6Addr + // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if + // needed. + if remoteIPv6 == nil { + if _, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { + t.Fatalf("unable to ip addr add on container %s: %s", dut.c.Name, err) + } + // Now try again, to make sure that it worked. + _, dutDeviceInfo, err = deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + remoteIPv6 = dutDeviceInfo.IPv6Addr + if remoteIPv6 == nil { + t.Fatalf("unable to set IPv6 address on container %s", dut.c.Name) + } + } + const testNetDev = "eth2" + + return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev +} + +// Logs implements DUT.Logs. +func (dut *DockerDUT) Logs(ctx context.Context) (string, error) { + logs, err := dut.c.Logs(ctx) + if err != nil { + return "", err + } + return fmt.Sprintf(`====== Begin of DUT Logs ====== + +%s + +====== End of DUT Logs ======`, logs), nil +} + +// AddNetworks connects docker network with the container and assigns the specific IP. +func AddNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { + for _, dn := range networks { + ip := AddressInSubnet(addr, *dn.Subnet) + // Connect to the network with the specified IP address. + if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { + return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) + } + } + return nil +} + +// AddressInSubnet combines the subnet provided with the address and returns a +// new address. The return address bits come from the subnet where the mask is 1 +// and from the ip address where the mask is 0. +func AddressInSubnet(addr net.IP, subnet net.IPNet) net.IP { + var octets []byte + for i := 0; i < 4; i++ { + octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) + } + return net.IP(octets) +} + +// deviceByIP finds a deviceInfo and device name from an IP address. +func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out) + } + devs, err := netdevs.ParseDevices(out) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out) + } + testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) + } + return testDevice, deviceInfo, nil +} + +// createDockerNetwork makes a randomly-named network that will start with the +// namePrefix. The network will be a random /24 subnet. +func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { + randSource := rand.NewSource(time.Now().UnixNano()) + r1 := rand.New(randSource) + // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) + n.Subnet = &net.IPNet{ + IP: ip, + Mask: ip.DefaultMask(), + } + return n.Create(ctx) +} + +// StartContainer will create a container instance from runOpts, connect it +// with the specified docker networks and start executing the specified cmd. +func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, cmd ...string) error { + conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...) + _ = netconf + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := c.CreateFrom(ctx, conf, hostconf, nil); err != nil { + return fmt.Errorf("unable to create container %s: %w", c.Name, err) + } + + if err := AddNetworks(ctx, c, containerAddr, ns); err != nil { + return fmt.Errorf("unable to connect the container with the networks: %w", err) + } + + if err := c.Start(ctx); err != nil { + return fmt.Errorf("unable to start container %s: %w", c.Name, err) + } + return nil +} diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go index c0a2620de..c598bfc29 100644 --- a/test/packetimpact/runner/packetimpact_test.go +++ b/test/packetimpact/runner/packetimpact_test.go @@ -16,330 +16,17 @@ package packetimpact_test import ( + "context" "flag" - "fmt" - "io/ioutil" - "log" - "math/rand" - "net" - "os" - "os/exec" - "path" - "strings" "testing" - "time" - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/packetimpact/netdevs" + "gvisor.dev/gvisor/test/packetimpact/runner" ) -// stringList implements flag.Value. -type stringList []string - -// String implements flag.Value.String. -func (l *stringList) String() string { - return strings.Join(*l, ",") -} - -// Set implements flag.Value.Set. -func (l *stringList) Set(value string) error { - *l = append(*l, value) - return nil -} - -var ( - dutPlatform = flag.String("dut_platform", "", "either \"linux\" or \"netstack\"") - testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary") - tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump") - extraTestArgs = stringList{} - expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run") - - dutAddr = net.IPv4(0, 0, 0, 10) - testbenchAddr = net.IPv4(0, 0, 0, 20) -) - -const ctrlPort = "40000" - -// logger implements testutil.Logger. -// -// Labels logs based on their source and formats multi-line logs. -type logger string - -// Name implements testutil.Logger.Name. -func (l logger) Name() string { - return string(l) -} - -// Logf implements testutil.Logger.Logf. -func (l logger) Logf(format string, args ...interface{}) { - lines := strings.Split(fmt.Sprintf(format, args...), "\n") - log.Printf("%s: %s", l, lines[0]) - for _, line := range lines[1:] { - log.Printf("%*s %s", len(l), "", line) - } +func init() { + runner.RegisterFlags(flag.CommandLine) } func TestOne(t *testing.T) { - flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") - flag.Parse() - if *dutPlatform != "linux" && *dutPlatform != "netstack" { - t.Fatal("--dut_platform should be either linux or netstack") - } - if *testbenchBinary == "" { - t.Fatal("--testbench_binary is missing") - } - if *dutPlatform == "netstack" { - if _, err := dockerutil.RuntimePath(); err != nil { - t.Fatal("--runtime is missing or invalid with --dut_platform=netstack:", err) - } - } - dockerutil.EnsureSupportedDockerVersion() - - // Create the networks needed for the test. One control network is needed for - // the gRPC control packets and one test network on which to transmit the test - // packets. - ctrlNet := dockerutil.NewDockerNetwork(logger("ctrlNet")) - testNet := dockerutil.NewDockerNetwork(logger("testNet")) - for _, dn := range []*dockerutil.DockerNetwork{ctrlNet, testNet} { - for { - if err := createDockerNetwork(dn); err != nil { - t.Log("creating docker network:", err) - const wait = 100 * time.Millisecond - t.Logf("sleeping %s and will try creating docker network again", wait) - // This can fail if another docker network claimed the same IP so we'll - // just try again. - time.Sleep(wait) - continue - } - break - } - defer func(dn *dockerutil.DockerNetwork) { - if err := dn.Cleanup(); err != nil { - t.Errorf("unable to cleanup container %s: %s", dn.Name, err) - } - }(dn) - } - - tmpDir, err := ioutil.TempDir("", "container-output") - if err != nil { - t.Fatal("creating temp dir:", err) - } - defer os.RemoveAll(tmpDir) - - const testOutputDir = "/tmp/testoutput" - - runOpts := dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Extra: []string{"--sysctl", "net.ipv6.conf.all.disable_ipv6=0", "--rm", "-v", tmpDir + ":" + testOutputDir}, - Foreground: true, - } - - // Create the Docker container for the DUT. - dut := dockerutil.MakeDocker(logger("dut")) - if *dutPlatform == "linux" { - dut.Runtime = "" - } - - const containerPosixServerBinary = "/packetimpact/posix_server" - dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server") - - if err := dut.Create(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort); err != nil { - t.Fatalf("unable to create container %s: %s", dut.Name, err) - } - defer dut.CleanUp() - - // Add ctrlNet as eth1 and testNet as eth2. - const testNetDev = "eth2" - if err := addNetworks(dut, dutAddr, []*dockerutil.DockerNetwork{ctrlNet, testNet}); err != nil { - t.Fatal(err) - } - - if err := dut.Start(); err != nil { - t.Fatalf("unable to start container %s: %s", dut.Name, err) - } - - if _, err := dut.WaitForOutput("Server listening.*\n", 60*time.Second); err != nil { - t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err) - } - - dutTestDevice, dutDeviceInfo, err := deviceByIP(dut, addressInSubnet(dutAddr, *testNet.Subnet)) - if err != nil { - t.Fatal(err) - } - - remoteMAC := dutDeviceInfo.MAC - remoteIPv6 := dutDeviceInfo.IPv6Addr - // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if - // needed. - if remoteIPv6 == nil { - if _, err := dut.Exec(dockerutil.RunOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { - t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err) - } - // Now try again, to make sure that it worked. - _, dutDeviceInfo, err = deviceByIP(dut, addressInSubnet(dutAddr, *testNet.Subnet)) - if err != nil { - t.Fatal(err) - } - remoteIPv6 = dutDeviceInfo.IPv6Addr - if remoteIPv6 == nil { - t.Fatal("unable to set IPv6 address on container", dut.Name) - } - } - - // Create the Docker container for the testbench. - testbench := dockerutil.MakeDocker(logger("testbench")) - testbench.Runtime = "" // The testbench always runs on Linux. - - tbb := path.Base(*testbenchBinary) - containerTestbenchBinary := "/packetimpact/" + tbb - runOpts = dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Extra: []string{"--sysctl", "net.ipv6.conf.all.disable_ipv6=0", "--rm", "-v", tmpDir + ":" + testOutputDir}, - Foreground: true, - } - testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb) - - // Run tcpdump in the test bench unbuffered, without DNS resolution, just on - // the interface with the test packets. - snifferArgs := []string{ - "tcpdump", - "-S", "-vvv", "-U", "-n", - "-i", testNetDev, - "-w", testOutputDir + "/dump.pcap", - } - snifferRegex := "tcpdump: listening.*\n" - if *tshark { - // Run tshark in the test bench unbuffered, without DNS resolution, just on - // the interface with the test packets. - snifferArgs = []string{ - "tshark", "-V", "-l", "-n", "-i", testNetDev, - "-o", "tcp.check_checksum:TRUE", - "-o", "udp.check_checksum:TRUE", - } - snifferRegex = "Capturing on.*\n" - } - - defer func() { - if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { - t.Error("unable to copy container output files:", err) - } - }() - - if err := testbench.Create(runOpts, snifferArgs...); err != nil { - t.Fatalf("unable to create container %s: %s", testbench.Name, err) - } - defer testbench.CleanUp() - - // Add ctrlNet as eth1 and testNet as eth2. - if err := addNetworks(testbench, testbenchAddr, []*dockerutil.DockerNetwork{ctrlNet, testNet}); err != nil { - t.Fatal(err) - } - - if err := testbench.Start(); err != nil { - t.Fatalf("unable to start container %s: %s", testbench.Name, err) - } - - // Kill so that it will flush output. - defer func() { - // Wait 1 second before killing tcpdump to give it time to flush - // any packets. On linux tests killing it immediately can - // sometimes result in partial pcaps. - time.Sleep(1 * time.Second) - testbench.Exec(dockerutil.RunOpts{}, "killall", snifferArgs[0]) - }() - - if _, err := testbench.WaitForOutput(snifferRegex, 60*time.Second); err != nil { - t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) - } - - // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it - // will issue a RST. To prevent this IPtables can be used to filter out all - // incoming packets. The raw socket that packetimpact tests use will still see - // everything. - if _, err := testbench.Exec(dockerutil.RunOpts{}, "iptables", "-A", "INPUT", "-i", testNetDev, "-j", "DROP"); err != nil { - t.Fatalf("unable to Exec iptables on container %s: %s", testbench.Name, err) - } - - // FIXME(b/156449515): Some piece of the system has a race. The old - // bash script version had a sleep, so we have one too. The race should - // be fixed and this sleep removed. - time.Sleep(time.Second) - - // Start a packetimpact test on the test bench. The packetimpact test sends - // and receives packets and also sends POSIX socket commands to the - // posix_server to be executed on the DUT. - testArgs := []string{containerTestbenchBinary} - testArgs = append(testArgs, extraTestArgs...) - testArgs = append(testArgs, - "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(), - "--posix_server_port", ctrlPort, - "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(), - "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(), - "--remote_ipv6", remoteIPv6.String(), - "--remote_mac", remoteMAC.String(), - "--device", testNetDev, - "--dut_type", *dutPlatform, - ) - _, err = testbench.Exec(dockerutil.RunOpts{}, testArgs...) - if !*expectFailure && err != nil { - t.Fatal("test failed:", err) - } - if *expectFailure && err == nil { - t.Fatal("test failure expected but the test succeeded, enable the test and mark the corresponding bug as fixed") - } -} - -func addNetworks(d *dockerutil.Docker, addr net.IP, networks []*dockerutil.DockerNetwork) error { - for _, dn := range networks { - ip := addressInSubnet(addr, *dn.Subnet) - // Connect to the network with the specified IP address. - if err := dn.Connect(d, "--ip", ip.String()); err != nil { - return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) - } - } - return nil -} - -// addressInSubnet combines the subnet provided with the address and returns a -// new address. The return address bits come from the subnet where the mask is 1 -// and from the ip address where the mask is 0. -func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP { - var octets []byte - for i := 0; i < 4; i++ { - octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) - } - return net.IP(octets) -} - -// makeDockerNetwork makes a randomly-named network that will start with the -// namePrefix. The network will be a random /24 subnet. -func createDockerNetwork(n *dockerutil.DockerNetwork) error { - randSource := rand.NewSource(time.Now().UnixNano()) - r1 := rand.New(randSource) - // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. - ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) - n.Subnet = &net.IPNet{ - IP: ip, - Mask: ip.DefaultMask(), - } - return n.Create() -} - -// deviceByIP finds a deviceInfo and device name from an IP address. -func deviceByIP(d *dockerutil.Docker, ip net.IP) (string, netdevs.DeviceInfo, error) { - out, err := d.Exec(dockerutil.RunOpts{}, "ip", "addr", "show") - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err) - } - devs, err := netdevs.ParseDevices(out) - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err) - } - testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) - } - return testDevice, deviceInfo, nil + runner.TestWithDUT(context.Background(), t, runner.NewDockerDUT, runner.DutAddr) } diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index d19ec07d4..5a0ee1367 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -23,8 +23,8 @@ go_library( "//pkg/usermem", "//test/packetimpact/netdevs", "//test/packetimpact/proto:posix_server_go_proto", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", "@com_github_mohae_deepcopy//:go_default_library", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//keepalive:go_default_library", diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 8b4a4d905..a90046f69 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -41,16 +41,19 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) } -// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with +// pickPort makes a new socket and returns the socket FD and port. The domain +// should be AF_INET or AF_INET6. The caller must close the FD when done with // the port if there is no error. -func pickPort(domain, typ int) (int, uint16, error) { - fd, err := unix.Socket(domain, typ, 0) +func pickPort(domain, typ int) (fd int, port uint16, err error) { + fd, err = unix.Socket(domain, typ, 0) if err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("creating socket: %w", err) } defer func() { if err != nil { - err = multierr.Append(err, unix.Close(fd)) + if cerr := unix.Close(fd); cerr != nil { + err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr)) + } } }() var sa unix.Sockaddr @@ -60,22 +63,22 @@ func pickPort(domain, typ int) (int, uint16, error) { copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4()) sa = &sa4 case unix.AF_INET6: - var sa6 unix.SockaddrInet6 + sa6 := unix.SockaddrInet6{ZoneId: uint32(LocalInterfaceID)} copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16()) sa = &sa6 default: return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) } if err = unix.Bind(fd, sa); err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err) } sa, err = unix.Getsockname(fd) if err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("Getsocketname(%d): %w", fd, err) } - port, err := portFromSockaddr(sa) + port, err = portFromSockaddr(sa) if err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err) } return fd, port, nil } @@ -378,7 +381,7 @@ var _ layerState = (*udpState)(nil) func newUDPState(domain int, out, in UDP) (*udpState, error) { portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM) if err != nil { - return nil, err + return nil, fmt.Errorf("picking port: %w", err) } s := udpState{ out: UDP{SrcPort: &localPort}, @@ -426,7 +429,6 @@ type Connection struct { layerStates []layerState injector Injector sniffer Sniffer - t *testing.T } // Returns the default incoming frame against which to match. If received is @@ -459,7 +461,9 @@ func (conn *Connection) match(override, received Layers) bool { } // Close frees associated resources held by the Connection. -func (conn *Connection) Close() { +func (conn *Connection) Close(t *testing.T) { + t.Helper() + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) for _, s := range conn.layerStates { if err := s.close(); err != nil { @@ -467,7 +471,7 @@ func (conn *Connection) Close() { } } if errs != nil { - conn.t.Fatalf("unable to close %+v: %s", conn, errs) + t.Fatalf("unable to close %+v: %s", conn, errs) } } @@ -479,7 +483,9 @@ func (conn *Connection) Close() { // overriden first. As an example, valid values of overrideLayers for a TCP- // over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and // [Ethernet, IPv4, TCP]. -func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers { +func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers { + t.Helper() + var layersToSend Layers for i, s := range conn.layerStates { layer := s.outgoing() @@ -488,7 +494,7 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L // end. if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { if err := layer.merge(overrideLayers[j]); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) + t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) } } layersToSend = append(layersToSend, layer) @@ -502,21 +508,25 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L // This method is useful for sending out-of-band control messages such as // ICMP packets, where it would not make sense to update the transport layer's // state using the ICMP header. -func (conn *Connection) SendFrameStateless(frame Layers) { +func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) { + t.Helper() + outBytes, err := frame.ToBytes() if err != nil { - conn.t.Fatalf("can't build outgoing packet: %s", err) + t.Fatalf("can't build outgoing packet: %s", err) } - conn.injector.Send(outBytes) + conn.injector.Send(t, outBytes) } // SendFrame sends a frame on the wire and updates the state of all layers. -func (conn *Connection) SendFrame(frame Layers) { +func (conn *Connection) SendFrame(t *testing.T, frame Layers) { + t.Helper() + outBytes, err := frame.ToBytes() if err != nil { - conn.t.Fatalf("can't build outgoing packet: %s", err) + t.Fatalf("can't build outgoing packet: %s", err) } - conn.injector.Send(outBytes) + conn.injector.Send(t, outBytes) // frame might have nil values where the caller wanted to use default values. // sentFrame will have no nil values in it because it comes from parsing the @@ -525,7 +535,7 @@ func (conn *Connection) SendFrame(frame Layers) { // Update the state of each layer based on what was sent. for i, s := range conn.layerStates { if err := s.sent(sentFrame[i]); err != nil { - conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) } } } @@ -535,18 +545,22 @@ func (conn *Connection) SendFrame(frame Layers) { // // Types defined with Connection as the underlying type should expose // type-safe versions of this method. -func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...)) +func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + t.Helper() + + conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...)) } // recvFrame gets the next successfully parsed frame (of type Layers) within the // timeout provided. If no parsable frame arrives before the timeout, it returns // nil. -func (conn *Connection) recvFrame(timeout time.Duration) Layers { +func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers { + t.Helper() + if timeout <= 0 { return nil } - b := conn.sniffer.Recv(timeout) + b := conn.sniffer.Recv(t, timeout) if b == nil { return nil } @@ -566,43 +580,47 @@ func (e *layersError) Error() string { // Expect expects a frame with the final layerStates layer matching the // provided Layer within the timeout specified. If it doesn't arrive in time, // an error is returned. -func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { +func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) { + t.Helper() + // Make a frame that will ignore all but the final layer. layers := make([]Layer, len(conn.layerStates)) layers[len(layers)-1] = layer - gotFrame, err := conn.ExpectFrame(layers, timeout) + gotFrame, err := conn.ExpectFrame(t, layers, timeout) if err != nil { return nil, err } if len(conn.layerStates)-1 < len(gotFrame) { return gotFrame[len(conn.layerStates)-1], nil } - conn.t.Fatal("the received frame should be at least as long as the expected layers") + t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame) panic("unreachable") } // ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If one arrives in time, the Layers is returned without an // error. If it doesn't arrive in time, it returns nil and error is non-nil. -func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { +func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { + t.Helper() + deadline := time.Now().Add(timeout) var errs error for { var gotLayers Layers if timeout = time.Until(deadline); timeout > 0 { - gotLayers = conn.recvFrame(timeout) + gotLayers = conn.recvFrame(t, timeout) } if gotLayers == nil { if errs == nil { return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout) } - return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs) + return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout) } if conn.match(layers, gotLayers) { for i, s := range conn.layerStates { if err := s.received(gotLayers[i]); err != nil { - conn.t.Fatal(err) + t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) } } return gotLayers, nil @@ -613,8 +631,10 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *Connection) Drain() { - conn.sniffer.Drain() +func (conn *Connection) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. @@ -622,6 +642,8 @@ type TCPIPv4 Connection // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -647,57 +669,58 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { layerStates: []layerState{etherState, ipv4State, tcpState}, injector: injector, sniffer: sniffer, - t: t, } } // Connect performs a TCP 3-way handshake. The input Connection should have a // final TCP Layer. -func (conn *TCPIPv4) Connect() { - conn.t.Helper() +func (conn *TCPIPv4) Connect(t *testing.T) { + t.Helper() // Send the SYN. - conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { - conn.t.Fatalf("didn't get synack during handshake: %s", err) + t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) } // ConnectWithOptions performs a TCP 3-way handshake with given TCP options. // The input Connection should have a final TCP Layer. -func (conn *TCPIPv4) ConnectWithOptions(options []byte) { - conn.t.Helper() +func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { + t.Helper() // Send the SYN. - conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { - conn.t.Fatalf("didn't get synack during handshake: %s", err) + t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // ExpectNextData attempts to receive the next incoming segment for the @@ -706,9 +729,11 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio // It differs from ExpectData() in that here we are only interested in the next // received segment, while ExpectData() can receive multiple segments for the // connection until there is a match with given layers or a timeout. -func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + // Receive the first incoming TCP segment for this connection. - got, err := conn.ExpectData(&TCP{}, nil, timeout) + got, err := conn.ExpectData(t, &TCP{}, nil, timeout) if err != nil { return nil, err } @@ -717,7 +742,7 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) - tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length())) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length())) } if !(*Connection)(conn).match(expected, got) { return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) @@ -727,71 +752,91 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur // Send a packet with reasonable defaults. Potentially override the TCP layer in // the connection with the provided layer and add additionLayers. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&tcp}, additionalLayers...) +func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...) } // Close frees associated resources held by the TCPIPv4 connection. -func (conn *TCPIPv4) Close() { - (*Connection)(conn).Close() +func (conn *TCPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Expect expects a frame with the TCP layer matching the provided TCP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { - layer, err := (*Connection)(conn).Expect(&tcp, timeout) +func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &tcp, timeout) if layer == nil { return nil, err } gotTCP, ok := layer.(*TCP) if !ok { - conn.t.Fatalf("expected %s to be TCP", layer) + t.Fatalf("expected %s to be TCP", layer) } return gotTCP, err } -func (conn *TCPIPv4) tcpState() *tcpState { +func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState { + t.Helper() + state, ok := conn.layerStates[2].(*tcpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) } return state } -func (conn *TCPIPv4) ipv4State() *ipv4State { +func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv4State) if !ok { - conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) } return state } // RemoteSeqNum returns the next expected sequence number from the DUT. -func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { - return conn.tcpState().remoteSeqNum +func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).remoteSeqNum } // LocalSeqNum returns the next sequence number to send from the testbench. -func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { - return conn.tcpState().localSeqNum +func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).localSeqNum } // SynAck returns the SynAck that was part of the handshake. -func (conn *TCPIPv4) SynAck() *TCP { - return conn.tcpState().synAck +func (conn *TCPIPv4) SynAck(t *testing.T) *TCP { + t.Helper() + + return conn.tcpState(t).synAck } // LocalAddr gets the local socket address of this connection. -func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 { - sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) +func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) return sa } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *TCPIPv4) Drain() { - conn.sniffer.Drain() +func (conn *TCPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // IPv6Conn maintains the state for all the layers in a IPv6 connection. @@ -799,6 +844,8 @@ type IPv6Conn Connection // NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make EtherState: %s", err) @@ -821,25 +868,30 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { layerStates: []layerState{etherState, ipv6State}, injector: injector, sniffer: sniffer, - t: t, } } // Send sends a frame with ipv6 overriding the IPv6 layer defaults and // additionalLayers added after it. -func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...) +func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...) } // Close to clean up any resources held. -func (conn *IPv6Conn) Close() { - (*Connection)(conn).Close() +func (conn *IPv6Conn) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) { - return (*Connection)(conn).ExpectFrame(frame, timeout) +func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) } // UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. @@ -847,6 +899,8 @@ type UDPIPv4 Connection // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -872,79 +926,280 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { layerStates: []layerState{etherState, ipv4State, udpState}, injector: injector, sniffer: sniffer, - t: t, } } -func (conn *UDPIPv4) udpState() *udpState { +func (conn *UDPIPv4) udpState(t *testing.T) *udpState { + t.Helper() + state, ok := conn.layerStates[2].(*udpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) } return state } -func (conn *UDPIPv4) ipv4State() *ipv4State { +func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv4State) if !ok { - conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) } return state } // LocalAddr gets the local socket address of this connection. -func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 { - sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) +func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) return sa } // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. -func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&udp}, additionalLayers...) +func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) } // SendIP sends a packet with reasonable defaults, potentially overriding the // UDP and IPv4 headers and adding additionLayers. -func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...) +func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) { - conn.t.Helper() - layer, err := (*Connection)(conn).Expect(&udp, timeout) - if layer == nil { +func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) + if err != nil { return nil, err } gotUDP, ok := layer.(*UDP) if !ok { - conn.t.Fatalf("expected %s to be UDP", layer) + t.Fatalf("expected %s to be UDP", layer) } - return gotUDP, err + return gotUDP, nil } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) { - conn.t.Helper() +func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = &udp if payload.length() != 0 { expected = append(expected, &payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // Close frees associated resources held by the UDPIPv4 connection. -func (conn *UDPIPv4) Close() { - (*Connection)(conn).Close() +func (conn *UDPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. +type UDPIPv6 Connection + +// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults. +func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make IPv6State: %s", err) + } + udpState, err := newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + return UDPIPv6{ + layerStates: []layerState{etherState, ipv6State, udpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *UDPIPv6) udpState(t *testing.T) *udpState { + t.Helper() + + state, ok := conn.layerStates[2].(*udpState) + if !ok { + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + } + return state +} + +func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State { + t.Helper() + + state, ok := conn.layerStates[1].(*ipv6State) + if !ok { + t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) + } + return state +} + +// LocalAddr gets the local socket address of this connection. +func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 { + t.Helper() + + sa := &unix.SockaddrInet6{ + Port: int(*conn.udpState(t).out.SrcPort), + // Local address is in perspective to the remote host, so it's scoped to the + // ID of the remote interface. + ZoneId: uint32(RemoteInterfaceID), + } + copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr) + return sa +} + +// Send sends a packet with reasonable defaults, potentially overriding the UDP +// layer and adding additionLayers. +func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) +} + +// SendIPv6 sends a packet with reasonable defaults, potentially overriding the +// UDP and IPv6 headers and adding additionLayers. +func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) +} + +// Expect expects a frame with the UDP layer matching the provided UDP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) + if err != nil { + return nil, err + } + gotUDP, ok := layer.(*UDP) + if !ok { + t.Fatalf("expected %s to be UDP", layer) + } + return gotUDP, nil +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = &udp + if payload.length() != 0 { + expected = append(expected, &payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the UDPIPv6 connection. +func (conn *UDPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *UDPIPv4) Drain() { - conn.sniffer.Drain() +func (conn *UDPIPv6) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. +type TCPIPv6 Connection + +// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults. +func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make ipv6State: %s", err) + } + tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv6{ + layerStates: []layerState{etherState, ipv6State, tcpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *TCPIPv6) SrcPort() uint16 { + state := conn.layerStates[2].(*tcpState) + return *state.out.SrcPort +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the TCPIPv6 connection. +func (conn *TCPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 2a2afecb5..6165ab293 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -16,11 +16,13 @@ package testbench import ( "context" + "encoding/binary" "flag" "net" "strconv" "syscall" "testing" + "time" pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" @@ -31,13 +33,14 @@ import ( // DUT communicates with the DUT to force it to make POSIX calls. type DUT struct { - t *testing.T conn *grpc.ClientConn posixServer POSIXClient } // NewDUT creates a new connection with the DUT over gRPC. func NewDUT(t *testing.T) DUT { + t.Helper() + flag.Parse() if err := genPseudoFlags(); err != nil { t.Fatal("generating psuedo flags:", err) @@ -50,7 +53,6 @@ func NewDUT(t *testing.T) DUT { } posixServer := NewPOSIXClient(conn) return DUT{ - t: t, conn: conn, posixServer: posixServer, } @@ -61,8 +63,9 @@ func (dut *DUT) TearDown() { dut.conn.Close() } -func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { - dut.t.Helper() +func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr { + t.Helper() + switch s := sa.(type) { case *unix.SockaddrInet4: return &pb.Sockaddr{ @@ -87,12 +90,13 @@ func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { }, } } - dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + t.Fatalf("can't parse Sockaddr struct: %+v", sa) return nil } -func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { - dut.t.Helper() +func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr { + t.Helper() + switch s := sa.Sockaddr.(type) { case *pb.Sockaddr_In: ret := unix.SockaddrInet4{ @@ -106,31 +110,34 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { ZoneId: s.In6.GetScopeId(), } copy(ret.Addr[:], s.In6.GetAddr()) + return &ret } - dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + t.Fatalf("can't parse Sockaddr proto: %#v", sa) return nil } // CreateBoundSocket makes a new socket on the DUT, with type typ and protocol // proto, and bound to the IP address addr. Returns the new file descriptor and // the port that was selected on the DUT. -func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { - dut.t.Helper() +func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) { + t.Helper() + var fd int32 if addr.To4() != nil { - fd = dut.Socket(unix.AF_INET, typ, proto) + fd = dut.Socket(t, unix.AF_INET, typ, proto) sa := unix.SockaddrInet4{} copy(sa.Addr[:], addr.To4()) - dut.Bind(fd, &sa) + dut.Bind(t, fd, &sa) } else if addr.To16() != nil { - fd = dut.Socket(unix.AF_INET6, typ, proto) + fd = dut.Socket(t, unix.AF_INET6, typ, proto) sa := unix.SockaddrInet6{} copy(sa.Addr[:], addr.To16()) - dut.Bind(fd, &sa) + sa.ZoneId = uint32(RemoteInterfaceID) + dut.Bind(t, fd, &sa) } else { - dut.t.Fatalf("unknown ip addr type for remoteIP") + t.Fatalf("invalid IP address: %s", addr) } - sa := dut.GetSockName(fd) + sa := dut.GetSockName(t, fd) var port int switch s := sa.(type) { case *unix.SockaddrInet4: @@ -138,15 +145,17 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) case *unix.SockaddrInet6: port = s.Port default: - dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + t.Fatalf("unknown sockaddr type from getsockname: %T", sa) } return fd, uint16(port) } // CreateListener makes a new TCP connection. If it fails, the test ends. -func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { - fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4)) - dut.Listen(fd, backlog) +func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) { + t.Helper() + + fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4)) + dut.Listen(t, fd, backlog) return fd, remotePort } @@ -156,53 +165,57 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { // Accept calls accept on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // AcceptWithErrno. -func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) + fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd) if fd < 0 { - dut.t.Fatalf("failed to accept: %s", err) + t.Fatalf("failed to accept: %s", err) } return fd, sa } // AcceptWithErrno calls accept on the DUT. -func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { - dut.t.Helper() +func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + req := pb.AcceptRequest{ Sockfd: sockfd, } resp, err := dut.posixServer.Accept(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Accept: %s", err) + t.Fatalf("failed to call Accept: %s", err) } - return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } // Bind calls bind on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is // needed, use BindWithErrno. -func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.BindWithErrno(ctx, fd, sa) + ret, err := dut.BindWithErrno(ctx, t, fd, sa) if ret != 0 { - dut.t.Fatalf("failed to bind socket: %s", err) + t.Fatalf("failed to bind socket: %s", err) } } // BindWithErrno calls bind on the DUT. -func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.BindRequest{ Sockfd: fd, - Addr: dut.sockaddrToProto(sa), + Addr: dut.sockaddrToProto(t, sa), } resp, err := dut.posixServer.Bind(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Bind: %s", err) + t.Fatalf("failed to call Bind: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -210,25 +223,27 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) ( // Close calls close on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // CloseWithErrno. -func (dut *DUT) Close(fd int32) { - dut.t.Helper() +func (dut *DUT) Close(t *testing.T, fd int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.CloseWithErrno(ctx, fd) + ret, err := dut.CloseWithErrno(ctx, t, fd) if ret != 0 { - dut.t.Fatalf("failed to close: %s", err) + t.Fatalf("failed to close: %s", err) } } // CloseWithErrno calls close on the DUT. -func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) { + t.Helper() + req := pb.CloseRequest{ Fd: fd, } resp, err := dut.posixServer.Close(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Close: %s", err) + t.Fatalf("failed to call Close: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -236,28 +251,30 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { // Connect calls connect on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use ConnectWithErrno. -func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.ConnectWithErrno(ctx, fd, sa) + ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) // Ignore 'operation in progress' error that can be returned when the socket // is non-blocking. if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { - dut.t.Fatalf("failed to connect socket: %s", err) + t.Fatalf("failed to connect socket: %s", err) } } // ConnectWithErrno calls bind on the DUT. -func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.ConnectRequest{ Sockfd: fd, - Addr: dut.sockaddrToProto(sa), + Addr: dut.sockaddrToProto(t, sa), } resp, err := dut.posixServer.Connect(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Connect: %s", err) + t.Fatalf("failed to call Connect: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -265,20 +282,22 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr // Fcntl calls fcntl on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use FcntlWithErrno. -func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 { - dut.t.Helper() +func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg) + ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg) if ret == -1 { - dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) + t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) } return ret } // FcntlWithErrno calls fcntl on the DUT. -func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) { + t.Helper() + req := pb.FcntlRequest{ Fd: fd, Cmd: cmd, @@ -286,7 +305,7 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, } resp, err := dut.posixServer.Fcntl(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Fcntl: %s", err) + t.Fatalf("failed to call Fcntl: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -294,32 +313,35 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, // GetSockName calls getsockname on the DUT and causes a fatal test failure if // it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockNameWithErrno. -func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { - dut.t.Helper() +func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) + ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd) if ret != 0 { - dut.t.Fatalf("failed to getsockname: %s", err) + t.Fatalf("failed to getsockname: %s", err) } return sa } // GetSockNameWithErrno calls getsockname on the DUT. -func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { - dut.t.Helper() +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + req := pb.GetSockNameRequest{ Sockfd: sockfd, } resp, err := dut.posixServer.GetSockName(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Bind: %s", err) + t.Fatalf("failed to call Bind: %s", err) } - return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } -func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { - dut.t.Helper() +func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { + t.Helper() + req := pb.GetSockOptRequest{ Sockfd: sockfd, Level: level, @@ -329,11 +351,11 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i } resp, err := dut.posixServer.GetSockOpt(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call GetSockOpt: %s", err) + t.Fatalf("failed to call GetSockOpt: %s", err) } optval := resp.GetOptval() if optval == nil { - dut.t.Fatalf("GetSockOpt response does not contain a value") + t.Fatalf("GetSockOpt response does not contain a value") } return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) } @@ -343,13 +365,14 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i // needed, use GetSockOptWithErrno. Because endianess and the width of values // might differ between the testbench and DUT architectures, prefer to use a // more specific GetSockOptXxx function. -func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { - dut.t.Helper() +func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen) + ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen) if ret != 0 { - dut.t.Fatalf("failed to GetSockOpt: %s", err) + t.Fatalf("failed to GetSockOpt: %s", err) } return optval } @@ -357,12 +380,13 @@ func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { // GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the // width of values might differ between the testbench and DUT architectures, // prefer to use a more specific GetSockOptXxxWithErrno function. -func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) +func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval) + t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val) } return ret, bytesval.Bytesval, errno } @@ -370,24 +394,26 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, // GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the int optval or error handling // is needed, use GetSockOptIntWithErrno. -func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 { - dut.t.Helper() +func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname) + ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname) if ret != 0 { - dut.t.Fatalf("failed to GetSockOptInt: %s", err) + t.Fatalf("failed to GetSockOptInt: %s", err) } return intval } // GetSockOptIntWithErrno calls getsockopt with an integer optval. -func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) +func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) intval, ok := optval.Val.(*pb.SockOptVal_Intval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval) + t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val) } return ret, intval.Intval, errno } @@ -395,24 +421,26 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna // GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockOptTimevalWithErrno. -func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval { - dut.t.Helper() +func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname) + ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname) if ret != 0 { - dut.t.Fatalf("failed to GetSockOptTimeval: %s", err) + t.Fatalf("failed to GetSockOptTimeval: %s", err) } return timeval } // GetSockOptTimevalWithErrno calls getsockopt and returns a timeval. -func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) +func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) tv, ok := optval.Val.(*pb.SockOptVal_Timeval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval) + t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val) } timeval := unix.Timeval{ Sec: tv.Timeval.Seconds, @@ -424,26 +452,28 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o // Listen calls listen on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ListenWithErrno. -func (dut *DUT) Listen(sockfd, backlog int32) { - dut.t.Helper() +func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) + ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog) if ret != 0 { - dut.t.Fatalf("failed to listen: %s", err) + t.Fatalf("failed to listen: %s", err) } } // ListenWithErrno calls listen on the DUT. -func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) { + t.Helper() + req := pb.ListenRequest{ Sockfd: sockfd, Backlog: backlog, } resp, err := dut.posixServer.Listen(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Listen: %s", err) + t.Fatalf("failed to call Listen: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -451,20 +481,22 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int // Send calls send on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendWithErrno. -func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { - dut.t.Helper() +func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) + ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags) if ret == -1 { - dut.t.Fatalf("failed to send: %s", err) + t.Fatalf("failed to send: %s", err) } return ret } // SendWithErrno calls send on the DUT. -func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) { + t.Helper() + req := pb.SendRequest{ Sockfd: sockfd, Buf: buf, @@ -472,7 +504,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla } resp, err := dut.posixServer.Send(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Send: %s", err) + t.Fatalf("failed to call Send: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -480,48 +512,52 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla // SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendToWithErrno. -func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { - dut.t.Helper() +func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr) + ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr) if ret == -1 { - dut.t.Fatalf("failed to sendto: %s", err) + t.Fatalf("failed to sendto: %s", err) } return ret } // SendToWithErrno calls sendto on the DUT. -func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.SendToRequest{ Sockfd: sockfd, Buf: buf, Flags: flags, - DestAddr: dut.sockaddrToProto(destAddr), + DestAddr: dut.sockaddrToProto(t, destAddr), } resp, err := dut.posixServer.SendTo(ctx, &req) if err != nil { - dut.t.Fatalf("faled to call SendTo: %s", err) + t.Fatalf("faled to call SendTo: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } // SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking // is true, otherwise it will clear the flag. -func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) { - dut.t.Helper() - flags := dut.Fcntl(fd, unix.F_GETFL, 0) +func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { + t.Helper() + + flags := dut.Fcntl(t, fd, unix.F_GETFL, 0) if nonblocking { flags |= unix.O_NONBLOCK } else { flags &= ^unix.O_NONBLOCK } - dut.Fcntl(fd, unix.F_SETFL, flags) + dut.Fcntl(t, fd, unix.F_SETFL, flags) } -func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { - dut.t.Helper() +func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { + t.Helper() + req := pb.SetSockOptRequest{ Sockfd: sockfd, Level: level, @@ -530,7 +566,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op } resp, err := dut.posixServer.SetSockOpt(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call SetSockOpt: %s", err) + t.Fatalf("failed to call SetSockOpt: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -540,81 +576,89 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op // needed, use SetSockOptWithErrno. Because endianess and the width of values // might differ between the testbench and DUT architectures, prefer to use a // more specific SetSockOptXxx function. -func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { - dut.t.Helper() +func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) + ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOpt: %s", err) + t.Fatalf("failed to SetSockOpt: %s", err) } } // SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the // width of values might differ between the testbench and DUT architectures, // prefer to use a more specific SetSockOptXxxWithErrno function. -func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) { - dut.t.Helper() - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) } // SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the int optval or error handling // is needed, use SetSockOptIntWithErrno. -func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { - dut.t.Helper() +func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) + ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOptInt: %s", err) + t.Fatalf("failed to SetSockOptInt: %s", err) } } // SetSockOptIntWithErrno calls setsockopt with an integer optval. -func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) { - dut.t.Helper() - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) } // SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the timeout or error handling is // needed, use SetSockOptTimevalWithErrno. -func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { - dut.t.Helper() +func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) + ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv) if ret != 0 { - dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) + t.Fatalf("failed to SetSockOptTimeval: %s", err) } } // SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to // bytes. -func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { - dut.t.Helper() +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { + t.Helper() + timeval := pb.Timeval{ Seconds: int64(tv.Sec), Microseconds: int64(tv.Usec), } - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) } // Socket calls socket on the DUT and returns the file descriptor. If socket // fails on the DUT, the test ends. -func (dut *DUT) Socket(domain, typ, proto int32) int32 { - dut.t.Helper() - fd, err := dut.SocketWithErrno(domain, typ, proto) +func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 { + t.Helper() + + fd, err := dut.SocketWithErrno(t, domain, typ, proto) if fd < 0 { - dut.t.Fatalf("failed to create socket: %s", err) + t.Fatalf("failed to create socket: %s", err) } return fd } // SocketWithErrno calls socket on the DUT and returns the fd and errno. -func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) { + t.Helper() + req := pb.SocketRequest{ Domain: domain, Type: typ, @@ -623,7 +667,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { ctx := context.Background() resp, err := dut.posixServer.Socket(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Socket: %s", err) + t.Fatalf("failed to call Socket: %s", err) } return resp.GetFd(), syscall.Errno(resp.GetErrno_()) } @@ -631,20 +675,22 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { // Recv calls recv on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // RecvWithErrno. -func (dut *DUT) Recv(sockfd, len, flags int32) []byte { - dut.t.Helper() +func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags) if ret == -1 { - dut.t.Fatalf("failed to recv: %s", err) + t.Fatalf("failed to recv: %s", err) } return buf } // RecvWithErrno calls recv on the DUT. -func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { - dut.t.Helper() +func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) { + t.Helper() + req := pb.RecvRequest{ Sockfd: sockfd, Len: len, @@ -652,7 +698,47 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (in } resp, err := dut.posixServer.Recv(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Recv: %s", err) + t.Fatalf("failed to call Recv: %s", err) } return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } + +// SetSockLingerOption sets SO_LINGER socket option on the DUT. +func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Duration, enable bool) { + var linger unix.Linger + if enable { + linger.Onoff = 1 + } + linger.Linger = int32(timeout / time.Second) + + buf := make([]byte, 8) + binary.LittleEndian.PutUint32(buf, uint32(linger.Onoff)) + binary.LittleEndian.PutUint32(buf[4:], uint32(linger.Linger)) + dut.SetSockOpt(t, sockfd, unix.SOL_SOCKET, unix.SO_LINGER, buf) +} + +// Shutdown calls shutdown on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ShutdownWithErrno. +func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + return dut.ShutdownWithErrno(ctx, t, fd, how) +} + +// ShutdownWithErrno calls shutdown on the DUT. +func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) error { + t.Helper() + + req := pb.ShutdownRequest{ + Fd: fd, + How: how, + } + resp, err := dut.posixServer.Shutdown(ctx, &req) + if err != nil { + t.Fatalf("failed to call Shutdown: %s", err) + } + return syscall.Errno(resp.GetErrno_()) +} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index a8121b0da..a35562ca8 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -15,6 +15,7 @@ package testbench import ( + "encoding/binary" "encoding/hex" "fmt" "reflect" @@ -470,21 +471,11 @@ func (l *IPv6) ToBytes() ([]byte, error) { if l.NextHeader != nil { fields.NextHeader = *l.NextHeader } else { - switch n := l.next().(type) { - case *TCP: - fields.NextHeader = uint8(header.TCPProtocolNumber) - case *UDP: - fields.NextHeader = uint8(header.UDPProtocolNumber) - case *ICMPv6: - fields.NextHeader = uint8(header.ICMPv6ProtocolNumber) - case *IPv6HopByHopOptionsExtHdr: - fields.NextHeader = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) - case *IPv6DestinationOptionsExtHdr: - fields.NextHeader = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) - default: - // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("ToBytes can't deduce the IPv6 header's next protocol: %#v", n) + nh, err := nextHeaderByLayer(l.next()) + if err != nil { + return nil, err } + fields.NextHeader = nh } if l.HopLimit != nil { fields.HopLimit = *l.HopLimit @@ -514,6 +505,8 @@ func nextIPv6PayloadParser(nextHeader uint8) layerParser { return parseIPv6HopByHopOptionsExtHdr case header.IPv6DestinationOptionsExtHdrIdentifier: return parseIPv6DestinationOptionsExtHdr + case header.IPv6FragmentExtHdrIdentifier: + return parseIPv6FragmentExtHdr } return parsePayload } @@ -566,14 +559,56 @@ type IPv6DestinationOptionsExtHdr struct { Options []byte } +// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header. +type IPv6FragmentExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + FragmentOffset *uint16 + MoreFragments *bool + Identification *uint32 +} + +// nextHeaderByLayer finds the correct next header protocol value for layer l. +func nextHeaderByLayer(l Layer) (uint8, error) { + if l == nil { + return uint8(header.IPv6NoNextHeaderIdentifier), nil + } + switch l.(type) { + case *TCP: + return uint8(header.TCPProtocolNumber), nil + case *UDP: + return uint8(header.UDPProtocolNumber), nil + case *ICMPv6: + return uint8(header.ICMPv6ProtocolNumber), nil + case *Payload: + return uint8(header.IPv6NoNextHeaderIdentifier), nil + case *IPv6HopByHopOptionsExtHdr: + return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil + case *IPv6DestinationOptionsExtHdr: + return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil + case *IPv6FragmentExtHdr: + return uint8(header.IPv6FragmentExtHdrIdentifier), nil + default: + // TODO(b/161005083): Support more protocols as needed. + return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l) + } +} + // ipv6OptionsExtHdrToBytes serializes an options extension header into bytes. -func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, options []byte) []byte { +func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) { length := len(options) + 2 + if length%8 != 0 { + return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options)) + } bytes := make([]byte, length) - if nextHeader == nil { - bytes[0] = byte(header.IPv6NoNextHeaderIdentifier) - } else { + if nextHeader != nil { bytes[0] = byte(*nextHeader) + } else { + nh, err := nextHeaderByLayer(nextLayer) + if err != nil { + return nil, err + } + bytes[0] = nh } // ExtHdrLen field is the length of the extension header // in 8-octet unit, ignoring the first 8 octets. @@ -581,7 +616,7 @@ func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, // https://tools.ietf.org/html/rfc2460#section-4.6 bytes[1] = uint8((length - 8) / 8) copy(bytes[2:], options) - return bytes + return bytes, nil } // IPv6ExtHdrIdent is a helper routine that allocates a new @@ -591,14 +626,45 @@ func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6Extens return &id } -// ToBytes implements Layer.ToBytes +// ToBytes implements Layer.ToBytes. func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) { - return ipv6OptionsExtHdrToBytes(l.NextHeader, l.Options), nil + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) } -// ToBytes implements Layer.ToBytes +// ToBytes implements Layer.ToBytes. func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) { - return ipv6OptionsExtHdrToBytes(l.NextHeader, l.Options), nil + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) { + var offset, mflag uint16 + var ident uint32 + bytes := make([]byte, header.IPv6FragmentExtHdrLength) + if l.NextHeader != nil { + bytes[0] = byte(*l.NextHeader) + } else { + nh, err := nextHeaderByLayer(l.next()) + if err != nil { + return nil, err + } + bytes[0] = nh + } + bytes[1] = 0 // reserved + if l.MoreFragments != nil && *l.MoreFragments { + mflag = 1 + } + if l.FragmentOffset != nil { + offset = *l.FragmentOffset + } + if l.Identification != nil { + ident = *l.Identification + } + offsetAndMflag := offset<<3 | mflag + binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag) + binary.BigEndian.PutUint32(bytes[4:], ident) + + return bytes, nil } // parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader @@ -631,6 +697,26 @@ func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) { return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser } +// Bool is a helper routine that allocates a new +// bool value to store v and returns a pointer to it. +func Bool(v bool) *bool { + return &v +} + +// parseIPv6FragmentExtHdr parses the bytes assuming that they start +// with an IPv6 Fragment Extension Header. +func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) { + nextHeader := b[0] + var extHdr header.IPv6FragmentExtHdr + copy(extHdr[:], b[2:]) + return &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)), + FragmentOffset: Uint16(extHdr.FragmentOffset()), + MoreFragments: Bool(extHdr.More()), + Identification: Uint32(extHdr.ID()), + }, nextIPv6PayloadParser(nextHeader) +} + func (l *IPv6HopByHopOptionsExtHdr) length() int { return len(l.Options) + 2 } @@ -667,13 +753,31 @@ func (l *IPv6DestinationOptionsExtHdr) String() string { return stringLayer(l) } +func (*IPv6FragmentExtHdr) length() int { + return header.IPv6FragmentExtHdrLength +} + +func (l *IPv6FragmentExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6FragmentExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6FragmentExtHdr) String() string { + return stringLayer(l) +} + // ICMPv6 can construct and match an ICMPv6 encapsulation. type ICMPv6 struct { LayerBase - Type *header.ICMPv6Type - Code *byte - Checksum *uint16 - NDPPayload []byte + Type *header.ICMPv6Type + Code *header.ICMPv6Code + Checksum *uint16 + Payload []byte } func (l *ICMPv6) String() string { @@ -684,7 +788,7 @@ func (l *ICMPv6) String() string { // ToBytes implements Layer.ToBytes. func (l *ICMPv6) ToBytes() ([]byte, error) { - b := make([]byte, header.ICMPv6HeaderSize+len(l.NDPPayload)) + b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload)) h := header.ICMPv6(b) if l.Type != nil { h.SetType(*l.Type) @@ -692,7 +796,7 @@ func (l *ICMPv6) ToBytes() ([]byte, error) { if l.Code != nil { h.SetCode(*l.Code) } - copy(h.NDPPayload(), l.NDPPayload) + copy(h.NDPPayload(), l.Payload) if l.Checksum != nil { h.SetChecksum(*l.Checksum) } else { @@ -701,7 +805,11 @@ func (l *ICMPv6) ToBytes() ([]byte, error) { // We need to search forward to find the IPv6 header. for prev := l.Prev(); prev != nil; prev = prev.Prev() { if ipv6, ok := prev.(*IPv6); ok { - h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{})) + payload, err := payload(l) + if err != nil { + return nil, err + } + h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload)) break } } @@ -715,6 +823,12 @@ func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type { return &v } +// ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store +// v and returns a pointer to it. +func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code { + return &v +} + // Byte is a helper routine that allocates a new byte value to store // v and returns a pointer to it. func Byte(v byte) *byte { @@ -725,10 +839,10 @@ func Byte(v byte) *byte { func parseICMPv6(b []byte) (Layer, layerParser) { h := header.ICMPv6(b) icmpv6 := ICMPv6{ - Type: ICMPv6Type(h.Type()), - Code: Byte(h.Code()), - Checksum: Uint16(h.Checksum()), - NDPPayload: h.NDPPayload(), + Type: ICMPv6Type(h.Type()), + Code: ICMPv6Code(h.Code()), + Checksum: Uint16(h.Checksum()), + Payload: h.NDPPayload(), } return &icmpv6, nil } @@ -738,7 +852,7 @@ func (l *ICMPv6) match(other Layer) bool { } func (l *ICMPv6) length() int { - return header.ICMPv6HeaderSize + len(l.NDPPayload) + return header.ICMPv6HeaderSize + len(l.Payload) } // merge overrides the values in l with the values from other but only in fields @@ -753,11 +867,17 @@ func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type { return &t } +// ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value +// to store t and returns a pointer to it. +func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code { + return &t +} + // ICMPv4 can construct and match an ICMPv4 encapsulation. type ICMPv4 struct { LayerBase Type *header.ICMPv4Type - Code *uint8 + Code *header.ICMPv4Code Checksum *uint16 } @@ -773,7 +893,7 @@ func (l *ICMPv4) ToBytes() ([]byte, error) { h.SetType(*l.Type) } if l.Code != nil { - h.SetCode(byte(*l.Code)) + h.SetCode(*l.Code) } if l.Checksum != nil { h.SetChecksum(*l.Checksum) @@ -793,7 +913,7 @@ func parseICMPv4(b []byte) (Layer, layerParser) { h := header.ICMPv4(b) icmpv4 := ICMPv4{ Type: ICMPv4Type(h.Type()), - Code: Uint8(h.Code()), + Code: ICMPv4Code(h.Code()), Checksum: Uint16(h.Checksum()), } return &icmpv4, parsePayload @@ -904,12 +1024,14 @@ func payload(l Layer) (buffer.VectorisedView, error) { func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { totalLength := uint16(totalLength(l)) var xsum uint16 - switch s := l.Prev().(type) { + switch p := l.Prev().(type) { case *IPv4: - xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength) + xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) + case *IPv6: + xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) default: - // TODO(b/150301488): Support more protocols, like IPv6. - return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s) + // TODO(b/161246171): Support more protocols. + return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p) } payloadBytes, err := payload(l) if err != nil { diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index 382a983a1..eca0780b5 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -593,10 +593,107 @@ func TestIPv6ExtHdrOptions(t *testing.T) { Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, }, &ICMPv6{ - Type: ICMPv6Type(header.ICMPv6ParamProblem), - Code: Byte(0), - Checksum: Uint16(0x5f98), - NDPPayload: []byte{0x00, 0x00, 0x00, 0x06}, + Type: ICMPv6Type(header.ICMPv6ParamProblem), + Code: ICMPv6Code(header.ICMPv6ErroneousHeader), + Checksum: Uint16(0x5f98), + Payload: []byte{0x00, 0x00, 0x00, 0x06}, + }, + }, + }, + { + description: "IPv6/HopByHop/Fragment", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, 0x2a, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(false), + Identification: Uint32(42), + }, + &Payload{ + Bytes: nil, + }, + }, + }, + { + description: "IPv6/DestOpt/Fragment/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x1b, 0x3c, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // Destination Options + 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6DestinationOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(true), + Identification: Uint32(42), + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + { + description: "IPv6/Fragment/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x2c, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(true), + Identification: Uint32(42), + }, + &Payload{ + Bytes: []byte("Sample Data"), }, }, }, @@ -606,6 +703,19 @@ func TestIPv6ExtHdrOptions(t *testing.T) { if !layers.match(tt.wantLayers) { t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) } + // Make sure we can generate correct next header values and checksums + for _, layer := range layers { + switch layer := layer.(type) { + case *IPv6HopByHopOptionsExtHdr: + layer.NextHeader = nil + case *IPv6DestinationOptionsExtHdr: + layer.NextHeader = nil + case *IPv6FragmentExtHdr: + layer.NextHeader = nil + case *ICMPv6: + layer.Checksum = nil + } + } gotBytes, err := layers.ToBytes() if err != nil { t.Fatalf("ToBytes() failed on %s: %s", &layers, err) diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 278229b7e..193bb2dc8 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -28,7 +28,6 @@ import ( // Sniffer can sniff raw packets on the wire. type Sniffer struct { - t *testing.T fd int } @@ -40,6 +39,8 @@ func htons(x uint16) uint16 { // NewSniffer creates a Sniffer connected to *device. func NewSniffer(t *testing.T) (Sniffer, error) { + t.Helper() + snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) if err != nil { return Sniffer{}, err @@ -51,7 +52,6 @@ func NewSniffer(t *testing.T) (Sniffer, error) { t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) } return Sniffer{ - t: t, fd: snifferFd, }, nil } @@ -61,7 +61,9 @@ func NewSniffer(t *testing.T) (Sniffer, error) { const maxReadSize int = 65536 // Recv tries to read one frame until the timeout is up. -func (s *Sniffer) Recv(timeout time.Duration) []byte { +func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte { + t.Helper() + deadline := time.Now().Add(timeout) for { timeout = deadline.Sub(time.Now()) @@ -75,7 +77,7 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { - s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) + t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) } buf := make([]byte, maxReadSize) @@ -85,10 +87,10 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { continue } if err != nil { - s.t.Fatalf("can't read: %s", err) + t.Fatalf("can't read: %s", err) } if nread > maxReadSize { - s.t.Fatalf("received a truncated frame of %d bytes", nread) + t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize) } return buf[:nread] } @@ -96,14 +98,16 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { // Drain drains the Sniffer's socket receive buffer by receiving until there's // nothing else to receive. -func (s *Sniffer) Drain() { - s.t.Helper() +func (s *Sniffer) Drain(t *testing.T) { + t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) if err != nil { - s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + t.Fatalf("failed to get sniffer socket fd flags: %s", err) } - if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { - s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + nonBlockingFlags := flags | unix.O_NONBLOCK + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil { + t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err) } for { buf := make([]byte, maxReadSize) @@ -113,7 +117,7 @@ func (s *Sniffer) Drain() { } } if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { - s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err) } } @@ -128,13 +132,14 @@ func (s *Sniffer) close() error { // Injector can inject raw frames. type Injector struct { - t *testing.T fd int } // NewInjector creates a new injector on *device. func NewInjector(t *testing.T) (Injector, error) { - ifInfo, err := net.InterfaceByName(Device) + t.Helper() + + ifInfo, err := net.InterfaceByName(LocalDevice) if err != nil { return Injector{}, err } @@ -156,15 +161,20 @@ func NewInjector(t *testing.T) (Injector, error) { return Injector{}, err } return Injector{ - t: t, fd: injectFd, }, nil } // Send a raw frame. -func (i *Injector) Send(b []byte) { - if _, err := unix.Write(i.fd, b); err != nil { - i.t.Fatalf("can't write: %s of len %d", err, len(b)) +func (i *Injector) Send(t *testing.T, b []byte) { + t.Helper() + + n, err := unix.Write(i.fd, b) + if err != nil { + t.Fatalf("can't write bytes of len %d: %s", len(b), err) + } + if n != len(b) { + t.Fatalf("got %d bytes written, want %d", n, len(b)) } } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index d64f32a5b..0073a1361 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -27,27 +27,44 @@ import ( ) var ( - // DUTType is the type of device under test. - DUTType = "" - // Device is the local device on the test network. - Device = "" + // Native indicates that the test is being run natively. + Native = false + // LocalDevice is the device that testbench uses to inject traffic. + LocalDevice = "" + // RemoteDevice is the device name on the DUT, individual tests can + // use the name to construct tests. + RemoteDevice = "" + // LocalIPv4 is the local IPv4 address on the test network. LocalIPv4 = "" + // RemoteIPv4 is the DUT's IPv4 address on the test network. + RemoteIPv4 = "" + // IPv4PrefixLength is the network prefix length of the IPv4 test network. + IPv4PrefixLength = 0 + // LocalIPv6 is the local IPv6 address on the test network. LocalIPv6 = "" + // RemoteIPv6 is the DUT's IPv6 address on the test network. + RemoteIPv6 = "" + + // LocalInterfaceID is the ID of the local interface on the test network. + LocalInterfaceID uint32 + // RemoteInterfaceID is the ID of the remote interface on the test network. + // + // Not using uint32 because package flag does not support uint32. + RemoteInterfaceID uint64 + // LocalMAC is the local MAC address on the test network. LocalMAC = "" + // RemoteMAC is the DUT's MAC address on the test network. + RemoteMAC = "" + // POSIXServerIP is the POSIX server's IP address on the control network. POSIXServerIP = "" // POSIXServerPort is the UDP port the POSIX server is bound to on the // control network. POSIXServerPort = 40000 - // RemoteIPv4 is the DUT's IPv4 address on the test network. - RemoteIPv4 = "" - // RemoteIPv6 is the DUT's IPv6 address on the test network. - RemoteIPv6 = "" - // RemoteMAC is the DUT's MAC address on the test network. - RemoteMAC = "" + // RPCKeepalive is the gRPC keepalive. RPCKeepalive = 10 * time.Second // RPCTimeout is the gRPC timeout. @@ -66,8 +83,10 @@ func RegisterFlags(fs *flag.FlagSet) { fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") - fs.StringVar(&Device, "device", Device, "local device for test packets") - fs.StringVar(&DUTType, "dut_type", DUTType, "type of device under test") + fs.StringVar(&LocalDevice, "local_device", LocalDevice, "local device to inject traffic") + fs.StringVar(&RemoteDevice, "remote_device", RemoteDevice, "remote device on the DUT") + fs.BoolVar(&Native, "native", Native, "whether the test is running natively") + fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets") } // genPseudoFlags populates flag-like global config based on real flags. @@ -90,6 +109,13 @@ func genPseudoFlags() error { LocalMAC = deviceInfo.MAC.String() LocalIPv6 = deviceInfo.IPv6Addr.String() + LocalInterfaceID = deviceInfo.ID + + if deviceInfo.IPv4Net != nil { + IPv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size() + } else { + IPv4PrefixLength, _ = net.ParseIP(LocalIPv4).DefaultMask().Size() + } return nil } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 85749c559..94731c64b 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -18,8 +18,6 @@ packetimpact_go_test( packetimpact_go_test( name = "ipv4_id_uniqueness", srcs = ["ipv4_id_uniqueness_test.go"], - # TODO(b/157506701) Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/abi/linux", "//pkg/tcpip/header", @@ -29,14 +27,35 @@ packetimpact_go_test( ) packetimpact_go_test( - name = "udp_recv_multicast", - srcs = ["udp_recv_multicast_test.go"], - # TODO(b/152813495): Fix netstack then remove the line below. - expect_netstack_failure = True, + name = "udp_discard_mcast_source_addr", + srcs = ["udp_discard_mcast_source_addr_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_recv_mcast_bcast", + srcs = ["udp_recv_mcast_bcast_test.go"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_any_addr_recv_unicast", + srcs = ["udp_any_addr_recv_unicast_test.go"], + deps = [ + "//pkg/tcpip", + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -147,8 +166,8 @@ packetimpact_go_test( ) packetimpact_go_test( - name = "tcp_close_wait_ack", - srcs = ["tcp_close_wait_ack_test.go"], + name = "tcp_unacc_seq_ack", + srcs = ["tcp_unacc_seq_ack_test.go"], deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", @@ -211,6 +230,16 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "tcp_network_unreachable", + srcs = ["tcp_network_unreachable_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "tcp_cork_mss", srcs = ["tcp_cork_mss_test.go"], deps = [ @@ -231,6 +260,28 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "tcp_timewait_reset", + srcs = ["tcp_timewait_reset_test.go"], + # TODO(b/168523247): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_queue_send_in_syn_sent", + srcs = ["tcp_queue_send_in_syn_sent_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "icmpv6_param_problem", srcs = ["icmpv6_param_problem_test.go"], # TODO(b/153485026): Fix netstack then remove the line below. @@ -257,10 +308,45 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "ipv6_fragment_reassembly", + srcs = ["ipv6_fragment_reassembly_test.go"], + # TODO(b/160919104): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "udp_send_recv_dgram", srcs = ["udp_send_recv_dgram_test.go"], deps = [ "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_linger", + srcs = ["tcp_linger_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_rcv_buf_space", + srcs = ["tcp_rcv_buf_space_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 407565078..a61054c2c 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -39,34 +39,34 @@ func TestFinWait2Timeout(t *testing.T) { t.Run(tt.description, func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() + defer conn.Close(t) + conn.Connect(t) - acceptFd, _ := dut.Accept(listenFd) + acceptFd, _ := dut.Accept(t, listenFd) if tt.linger2 { tv := unix.Timeval{Sec: 1, Usec: 0} - dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) + dut.SetSockOptTimeval(t, acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) } - dut.Close(acceptFd) + dut.Close(t, acceptFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) time.Sleep(5 * time.Second) - conn.Drain() + conn.Drain(t) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) if tt.linger2 { - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { + if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { t.Fatalf("expected no RST packets within ten seconds but got one: %s", got) } } diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go index 4d1d9a7f5..2d59d552d 100644 --- a/test/packetimpact/tests/icmpv6_param_problem_test.go +++ b/test/packetimpact/tests/icmpv6_param_problem_test.go @@ -34,19 +34,19 @@ func TestICMPv6ParamProblemTest(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) - defer conn.Close() + defer conn.Close(t) ipv6 := testbench.IPv6{ // 254 is reserved and used for experimentation and testing. This should // cause an error. NextHeader: testbench.Uint8(254), } icmpv6 := testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - NDPPayload: []byte("hello world"), + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Payload: []byte("hello world"), } - toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6) - (*testbench.Connection)(&conn).SendFrame(toSend) + toSend := (*testbench.Connection)(&conn).CreateFrame(t, testbench.Layers{&ipv6}, &icmpv6) + (*testbench.Connection)(&conn).SendFrame(t, toSend) // Build the expected ICMPv6 payload, which includes an index to the // problematic byte and also the problematic packet as described in @@ -62,8 +62,8 @@ func TestICMPv6ParamProblemTest(t *testing.T) { binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset) expectedPayload = append(b, expectedPayload...) expectedICMPv6 := testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), - NDPPayload: expectedPayload, + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Payload: expectedPayload, } paramProblem := testbench.Layers{ @@ -72,7 +72,7 @@ func TestICMPv6ParamProblemTest(t *testing.T) { &expectedICMPv6, } timeout := time.Second - if _, err := conn.ExpectFrame(paramProblem, timeout); err != nil { + if _, err := conn.ExpectFrame(t, paramProblem, timeout); err != nil { t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err) } } diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go index 70f6df5e0..cf881418c 100644 --- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -31,8 +31,8 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { - layers, err := conn.ExpectData(expect, expectPayload, time.Second) +func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { + layers, err := conn.ExpectData(t, expect, expectPayload, time.Second) if err != nil { return 0, fmt.Errorf("failed to receive TCP segment: %s", err) } @@ -69,17 +69,17 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - remoteFD, _ := dut.Accept(listenFD) - defer dut.Close(remoteFD) + conn.Connect(t) + remoteFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, remoteFD) - dut.SetSockOptInt(remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) // TODO(b/129291778) The following socket option clears the DF bit on // IP packets sent over the socket, and is currently not supported by @@ -87,30 +87,30 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { // socket option being not supported does not affect the operation of // this test. Once the socket option is supported, the following call // can be changed to simply assert success. - ret, errno := dut.SetSockOptIntWithErrno(context.Background(), remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) + ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) if ret == -1 && errno != unix.ENOTSUP { t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno) } samplePayload := &testbench.Payload{Bytes: tc.payload} - dut.Send(remoteFD, tc.payload, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, remoteFD, tc.payload, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err) } // Let the DUT estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - dut.Send(remoteFD, tc.payload, 0) - expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))} - originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + dut.Send(t, remoteFD, tc.payload, 0) + expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))} + originalID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) if err != nil { t.Fatalf("failed to receive TCP segment: %s", err) } - retransmitID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + retransmitID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) if err != nil { t.Fatalf("failed to receive retransmitted TCP segment: %s", err) } diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go new file mode 100644 index 000000000..a24c85566 --- /dev/null +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -0,0 +1,168 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_fragment_reassembly_test + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "flag" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +const ( + // The payload length for the first fragment we send. This number + // is a multiple of 8 near 750 (half of 1500). + firstPayloadLength = 752 + // The ID field for our outgoing fragments. + fragmentID = 1 + // A node must be able to accept a fragmented packet that, + // after reassembly, is as large as 1500 octets. + reassemblyCap = 1500 +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestIPv6FragmentReassembly(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close(t) + + firstPayloadToSend := make([]byte, firstPayloadLength) + for i := range firstPayloadToSend { + firstPayloadToSend[i] = 'A' + } + + secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize + secondPayloadToSend := firstPayloadToSend[:secondPayloadLength] + + icmpv6EchoPayload := make([]byte, 4) + binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0) + binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0) + icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...) + + lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) + icmpv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), + Payload: icmpv6EchoPayload, + } + icmpv6Bytes, err := icmpv6.ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6: %s", err) + } + cksum := header.ICMPv6Checksum( + header.ICMPv6(icmpv6Bytes), + lIP, + rIP, + buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}), + ) + + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + FragmentOffset: testbench.Uint16(0), + MoreFragments: testbench.Bool(true), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), + Payload: icmpv6EchoPayload, + Checksum: &cksum, + }) + + icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber) + + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8), + MoreFragments: testbench.Bool(false), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.Payload{ + Bytes: secondPayloadToSend, + }) + + gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + FragmentOffset: testbench.Uint16(0), + MoreFragments: testbench.Bool(true), + }, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoReply), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), + }, + }, time.Second) + if err != nil { + t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err) + } + + id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification + gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6: %s", err) + } + icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:] + receivedLen := len(icmpPayload) + wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen + wantFirstPayload := make([]byte, receivedLen) + for i := range wantFirstPayload { + wantFirstPayload[i] = 'A' + } + wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen] + if !bytes.Equal(icmpPayload, wantFirstPayload) { + t.Fatalf("received unexpected payload, got: %s, want: %s", + hex.Dump(icmpPayload), + hex.Dump(wantFirstPayload)) + } + + gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)), + MoreFragments: testbench.Bool(false), + Identification: &id, + }, + &testbench.ICMPv6{}, + }, time.Second) + if err != nil { + t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err) + } + secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err) + } + if !bytes.Equal(secondPayload, wantSecondPayload) { + t.Fatalf("received unexpected payload, got: %s, want: %s", + hex.Dump(secondPayload), + hex.Dump(wantSecondPayload)) + } +} diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go index d301d8829..e79d74476 100644 --- a/test/packetimpact/tests/ipv6_unknown_options_action_test.go +++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go @@ -23,21 +23,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - tb "gvisor.dev/gvisor/test/packetimpact/testbench" + "gvisor.dev/gvisor/test/packetimpact/testbench" ) func init() { - tb.RegisterFlags(flag.CommandLine) + testbench.RegisterFlags(flag.CommandLine) } -func mkHopByHopOptionsExtHdr(optType byte) tb.Layer { - return &tb.IPv6HopByHopOptionsExtHdr{ +func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6HopByHopOptionsExtHdr{ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, } } -func mkDestinationOptionsExtHdr(optType byte) tb.Layer { - return &tb.IPv6DestinationOptionsExtHdr{ +func mkDestinationOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6DestinationOptionsExtHdr{ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, } } @@ -49,7 +49,7 @@ func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte { func TestIPv6UnknownOptionAction(t *testing.T) { for _, tt := range []struct { description string - mkExtHdr func(optType byte) tb.Layer + mkExtHdr func(optType byte) testbench.Layer action header.IPv6OptionUnknownAction multicastDst bool wantICMPv6 bool @@ -140,21 +140,21 @@ func TestIPv6UnknownOptionAction(t *testing.T) { }, } { t.Run(tt.description, func(t *testing.T) { - dut := tb.NewDUT(t) + dut := testbench.NewDUT(t) defer dut.TearDown() - ipv6Conn := tb.NewIPv6Conn(t, tb.IPv6{}, tb.IPv6{}) - conn := (*tb.Connection)(&ipv6Conn) - defer ipv6Conn.Close() + ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + conn := (*testbench.Connection)(&ipv6Conn) + defer ipv6Conn.Close(t) - outgoingOverride := tb.Layers{} + outgoingOverride := testbench.Layers{} if tt.multicastDst { - outgoingOverride = tb.Layers{&tb.IPv6{ - DstAddr: tb.Address(tcpip.Address(net.ParseIP("ff02::1"))), + outgoingOverride = testbench.Layers{&testbench.IPv6{ + DstAddr: testbench.Address(tcpip.Address(net.ParseIP("ff02::1"))), }} } - outgoing := conn.CreateFrame(outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) - conn.SendFrame(outgoing) + outgoing := conn.CreateFrame(t, outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) + conn.SendFrame(t, outgoing) ipv6Sent := outgoing[1:] invokingPacket, err := ipv6Sent.ToBytes() if err != nil { @@ -167,13 +167,13 @@ func TestIPv6UnknownOptionAction(t *testing.T) { // after the IPv6 header (after NextHeader and ExtHdrLen). binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2) icmpv6Payload = append(icmpv6Payload, invokingPacket...) - gotICMPv6, err := ipv6Conn.ExpectFrame(tb.Layers{ - &tb.Ether{}, - &tb.IPv6{}, - &tb.ICMPv6{ - Type: tb.ICMPv6Type(header.ICMPv6ParamProblem), - Code: tb.Byte(2), - NDPPayload: icmpv6Payload, + gotICMPv6, err := ipv6Conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Code: testbench.ICMPv6Code(header.ICMPv6UnknownOption), + Payload: icmpv6Payload, }, }, time.Second) if tt.wantICMPv6 && err != nil { diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go deleted file mode 100644 index 6e7ff41d7..000000000 --- a/test/packetimpact/tests/tcp_close_wait_ack_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_close_wait_ack_test - -import ( - "flag" - "fmt" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.RegisterFlags(flag.CommandLine) -} - -func TestCloseWaitAck(t *testing.T) { - for _, tt := range []struct { - description string - makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP - seqNumOffset seqnum.Size - expectAck bool - }{ - {"OTW", GenerateOTWSeqSegment, 0, false}, - {"OTW", GenerateOTWSeqSegment, 1, true}, - {"OTW", GenerateOTWSeqSegment, 2, true}, - {"ACK", GenerateUnaccACKSegment, 0, false}, - {"ACK", GenerateUnaccACKSegment, 1, true}, - {"ACK", GenerateUnaccACKSegment, 2, true}, - } { - t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) - conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - - // Send a FIN to DUT to intiate the active close - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) - gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) - } - windowSize := seqnum.Size(*gotTCP.WindowSize) - - // Send a segment with OTW Seq / unacc ACK and expect an ACK back - conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) - gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) - if tt.expectAck && err != nil { - t.Fatalf("expected an ack but got none: %s", err) - } - if !tt.expectAck && gotAck != nil { - t.Fatalf("expected no ack but got one: %s", gotAck) - } - - // Now let's verify DUT is indeed in CLOSE_WAIT - dut.Close(acceptFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { - t.Fatalf("expected DUT to send a FIN: %s", err) - } - // Ack the FIN from DUT - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - // Send some extra data to DUT - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { - t.Fatalf("expected DUT to send an RST: %s", err) - } - }) - } -} - -// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the -// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK -// is expected from the receiver. -func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.LocalSeqNum().Add(windowSize) - otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} -} - -// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated -// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is -// expected from the receiver. -func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.RemoteSeqNum() - unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} -} diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go index fb8f48629..8feea4a82 100644 --- a/test/packetimpact/tests/tcp_cork_mss_test.go +++ b/test/packetimpact/tests/tcp_cork_mss_test.go @@ -32,53 +32,53 @@ func init() { func TestTCPCorkMSS(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) const mss = uint32(header.TCPDefaultMSS) options := make([]byte, header.TCPOptionMSSLength) header.EncodeMSSOption(mss, options) - conn.ConnectWithOptions(options) + conn.ConnectWithOptions(t, options) - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) - dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) + dut.SetSockOptInt(t, acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) // Let the dut application send 2 small segments to be held up and coalesced // until the application sends a larger segment to fill up to > MSS. sampleData := []byte("Sample Data") - dut.Send(acceptFD, sampleData, 0) - dut.Send(acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) expectedData := sampleData expectedData = append(expectedData, sampleData...) largeData := make([]byte, mss+1) expectedData = append(expectedData, largeData...) - dut.Send(acceptFD, largeData, 0) + dut.Send(t, acceptFD, largeData, 0) // Expect the segments to be coalesced and sent and capped to MSS. expectedPayload := testbench.Payload{Bytes: expectedData[:mss]} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the coalesced segment to be split and transmitted. expectedPayload = testbench.Payload{Bytes: expectedData[mss:]} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Check for segments to *not* be held up because of TCP_CORK when // the current send window is less than MSS. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) - dut.Send(acceptFD, sampleData, 0) - dut.Send(acceptFD, sampleData, 0) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) } diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go index 652b530d0..22937d92f 100644 --- a/test/packetimpact/tests/tcp_handshake_window_size_test.go +++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go @@ -33,14 +33,14 @@ func init() { func TestTCPHandshakeWindowSize(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Start handshake with zero window size. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK: %s", err) } // Update the advertised window size to a non-zero value with the ACK that @@ -48,10 +48,10 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // // Set the window size with MSB set and expect the dut to treat it as // an unsigned value. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) - acceptFd, _ := dut.Accept(listenFD) - defer dut.Close(acceptFd) + acceptFd, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFd) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} @@ -59,8 +59,8 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // Since we advertised a zero window followed by a non-zero window, // expect the dut to honor the recently advertised non-zero window // and actually send out the data instead of probing for zero window. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go new file mode 100644 index 000000000..b9a0409aa --- /dev/null +++ b/test/packetimpact/tests/tcp_linger_test.go @@ -0,0 +1,270 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_linger_test + +import ( + "context" + "flag" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func createSocket(t *testing.T, dut testbench.DUT) (int32, int32, testbench.TCPIPv4) { + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + return acceptFD, listenFD, conn +} + +func closeAll(t *testing.T, dut testbench.DUT, listenFD int32, conn testbench.TCPIPv4) { + conn.Close(t) + dut.Close(t, listenFD) + dut.TearDown() +} + +// lingerDuration is the timeout value used with SO_LINGER socket option. +const lingerDuration = 3 * time.Second + +// TestTCPLingerZeroTimeout tests when SO_LINGER is set with zero timeout. DUT +// should send RST-ACK when socket is closed. +func TestTCPLingerZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Close(t, acceptFD) + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerOff tests when SO_LINGER is not set. DUT should send FIN-ACK +// when socket is closed. +func TestTCPLingerOff(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.Close(t, acceptFD) + + // If SO_LINGER is not set, DUT should send a FIN-ACK. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerNonZeroTimeout tests when SO_LINGER is set with non-zero timeout. +// DUT should close the socket after timeout. +func TestTCPLingerNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerSendNonZeroTimeout tests when SO_LINGER is set with non-zero +// timeout and send a packet. DUT should close the socket after timeout. +func TestTCPLingerSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithSendNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerShutdownZeroTimeout tests SO_LINGER with shutdown() and zero +// timeout. DUT should send RST-ACK when socket is closed. +func TestTCPLingerShutdownZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + dut.Close(t, acceptFD) + + // Shutdown will send FIN-ACK with read/write option. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerShutdownSendNonZeroTimeout tests SO_LINGER with shutdown() and +// non-zero timeout. DUT should close the socket after timeout. +func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"shutdownRDWR", true}, + {"shutdownRDWR", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +func TestTCPLingerNonEstablished(t *testing.T) { + dut := testbench.NewDUT(t) + newFD := dut.Socket(t, unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) + dut.SetSockLingerOption(t, newFD, lingerDuration, true) + + // As the socket is in the initial state, Close() should not linger + // and return immediately. + start := time.Now() + dut.CloseWithErrno(context.Background(), t, newFD) + diff := time.Since(start) + + if diff > lingerDuration { + t.Errorf("expected close to return within %s, but returned after %s", lingerDuration, diff) + } + dut.TearDown() +} diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go new file mode 100644 index 000000000..2f57dff19 --- /dev/null +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -0,0 +1,141 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_synsent_reset_test + +import ( + "context" + "flag" + "net" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPSynSentUnreachable verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + port := uint16(9001) + conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) + defer conn.Close(t) + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4()) + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)} + layers = append(layers, &icmpv4, ip, tcp) + rawConn.SendFrameStateless(t, layers) + + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) { + t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err) + } +} + +// TestTCPSynSentUnreachable6 verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable6(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6)) + conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort}) + defer conn.Close(t) + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet6{ + Port: int(conn.SrcPort()), + ZoneId: uint32(testbench.RemoteInterfaceID), + } + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16()) + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv6) + if !ok { + t.Fatalf("expected %s to be IPv6", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv6 testbench.ICMPv6 = testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6DstUnreachable), + Code: testbench.ICMPv6Code(header.ICMPv6NetworkUnreachable), + // Per RFC 4443 3.1, the payload contains 4 zeroed bytes. + Payload: []byte{0, 0, 0, 0}, + } + layers = append(layers, &icmpv6, ip, tcp) + rawConn.SendFrameStateless(t, layers) + + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) { + t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err) + } +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index b9b3e91d3..82b7a85ff 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -31,12 +31,12 @@ func init() { func TestTcpNoAcceptCloseReset(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - conn.Connect() - defer conn.Close() - dut.Close(listenFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + conn.Connect(t) + defer conn.Close(t) + dut.Close(t, listenFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { t.Fatalf("expected a RST-ACK packet but got none: %s", err) } } diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index ad8c74234..08f759f7c 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -63,25 +63,25 @@ func TestTCPOutsideTheWindow(t *testing.T) { t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) - windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset - conn.Drain() + windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + tt.seqNumOffset + conn.Drain(t) // Ignore whatever incrementing that this out-of-order packet might cause // to the AckNum. - localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum())) - conn.Send(testbench.TCP{ + localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{ Flags: testbench.Uint8(tt.tcpFlags), - SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), }, tt.payload...) timeout := 3 * time.Second - gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) if tt.expectACK && err != nil { t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) } diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go index 55db4ece6..37f3b56dd 100644 --- a/test/packetimpact/tests/tcp_paws_mechanism_test.go +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -32,15 +32,15 @@ func init() { func TestPAWSMechanism(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) options := make([]byte, header.TCPOptionTSLength) header.EncodeTSOption(currentTS(), 0, options) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) - synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("didn't get synack during handshake: %s", err) } @@ -50,9 +50,9 @@ func TestPAWSMechanism(t *testing.T) { } tsecr := parsedSynOpts.TSVal header.EncodeTSOption(currentTS(), tsecr, options) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) sampleData := []byte("Sample Data") sentTSVal := currentTS() @@ -61,9 +61,9 @@ func TestPAWSMechanism(t *testing.T) { // every time we send one, it should not cause any flakiness because timestamps // only need to be non-decreasing. time.Sleep(3 * time.Millisecond) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK but got none: %s", err) } @@ -86,9 +86,9 @@ func TestPAWSMechanism(t *testing.T) { // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness // due to the exact same reasoning discussed above. time.Sleep(3 * time.Millisecond) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) } diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go index 8fbec893b..d9f3ea0f2 100644 --- a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go +++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go @@ -52,26 +52,26 @@ func TestQueueReceiveInSynSent(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) sampleData := []byte("Sample Data") - dut.SetNonBlocking(socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) { + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) } - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { t.Fatalf("expected a SYN from DUT, but got none: %s", err) } - if _, _, err := dut.RecvWithErrno(context.Background(), socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { + if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) } // Test blocking read. - dut.SetNonBlocking(socket, false) + dut.SetNonBlocking(t, socket, false) var wg sync.WaitGroup defer wg.Wait() @@ -86,7 +86,7 @@ func TestQueueReceiveInSynSent(t *testing.T) { block.Done() // Issue RECEIVE call in SYN-SENT, this should be queued for // process until the connection is established. - n, buff, err := dut.RecvWithErrno(ctx, socket, int32(len(sampleData)), 0) + n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) if tt.reset { if err != syscall.Errno(unix.ECONNREFUSED) { t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) @@ -112,19 +112,19 @@ func TestQueueReceiveInSynSent(t *testing.T) { time.Sleep(100 * time.Millisecond) if tt.reset { - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) return } // Bring the connection to Established. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } // Send sample payload and expect an ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } }) diff --git a/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go new file mode 100644 index 000000000..0ec8fd748 --- /dev/null +++ b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_queue_send_in_syn_sent_test + +import ( + "context" + "errors" + "flag" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestQueueSendInSynSent tests send behavior when the TCP state +// is SYN-SENT. +// It tests for 2 variants when in SYN_SENT state and: +// (1) DUT blocks on send and complete handshake +// (2) DUT blocks on send and receive a TCP RST. +func TestQueueSendInSynSent(t *testing.T) { + for _, tt := range []struct { + description string + reset bool + }{ + {description: "Complete handshake", reset: false}, + {description: "Send RST", reset: true}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + t.Fatalf("expected a SYN from DUT, but got none: %s", err) + } + if _, err := dut.SendWithErrno(context.Background(), t, socket, sampleData, 0); err != syscall.Errno(unix.EWOULDBLOCK) { + t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + } + + // Test blocking write. + dut.SetNonBlocking(t, socket, false) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + // Issue SEND call in SYN-SENT, this should be queued for + // process until the connection is established. + n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0) + if tt.reset { + if err != syscall.Errno(unix.ECONNREFUSED) { + t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + } + if n != -1 { + t.Errorf("expected return value %d, got %d", -1, n) + } + return + } + if n != int32(len(sampleData)) { + t.Errorf("failed to send on DUT: %s", err) + } + }() + + // Wait for the goroutine to be scheduled and before it + // blocks on endpoint send. + block.Wait() + // The following sleep is used to prevent the connection + // from being established before we are blocked on send. + time.Sleep(100 * time.Millisecond) + + if tt.reset { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + return + } + + // Bring the connection to Established. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + + // Expect the data from the DUT's enqueued send request. + // + // On Linux, this can be piggybacked with the ACK completing the + // handshake. On gVisor, getting such a piggyback is a bit more + // complicated because the actual data enqueuing occurs in the + // callers of endpoint Write. + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Send sample payload and expect an ACK to ensure connection is still ESTABLISHED. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_rcv_buf_space_test.go b/test/packetimpact/tests/tcp_rcv_buf_space_test.go new file mode 100644 index 000000000..cfbba1e8e --- /dev/null +++ b/test/packetimpact/tests/tcp_rcv_buf_space_test.go @@ -0,0 +1,80 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_rcv_buf_space_test + +import ( + "context" + "flag" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestReduceRecvBuf tests that a packet within window is still dropped +// if the available buffer space drops below the size of the incoming +// segment. +func TestReduceRecvBuf(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + // Set a small receive buffer for the test. + const rcvBufSz = 4096 + dut.SetSockOptInt(t, acceptFd, unix.SOL_SOCKET, unix.SO_RCVBUF, rcvBufSz) + + // Retrieve the actual buffer. + bufSz := dut.GetSockOptInt(t, acceptFd, unix.SOL_SOCKET, unix.SO_RCVBUF) + + // Generate a payload of 1 more than the actual buffer size used by the + // DUT. + sampleData := testbench.GenerateRandomPayload(t, int(bufSz)+1) + // Send and receive sample data to the dut. + const pktSize = 1400 + for payload := sampleData; len(payload) != 0; { + payloadBytes := pktSize + if l := len(payload); l < payloadBytes { + payloadBytes = l + } + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...) + payload = payload[payloadBytes:] + } + + // First read should read < len(sampleData) + if ret, _, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), 0); ret == -1 || int(ret) == len(sampleData) { + t.Fatalf("dut.RecvWithErrno(ctx, t, %d, %d, 0) = %d,_, %s", acceptFd, int32(len(sampleData)), ret, err) + } + + // Second read should return EAGAIN as the last segment should have been + // dropped due to it exceeding the receive buffer space available in the + // socket. + if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), syscall.MSG_DONTWAIT); got != nil || ret != -1 || err != syscall.EAGAIN { + t.Fatalf("expected no packets but got: %s", got) + } +} diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go index a5378a9dd..b4aeaab57 100644 --- a/test/packetimpact/tests/tcp_reordering_test.go +++ b/test/packetimpact/tests/tcp_reordering_test.go @@ -32,10 +32,10 @@ func init() { func TestReorderingWindow(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Enable SACK. opts := make([]byte, 40) @@ -49,18 +49,18 @@ func TestReorderingWindow(t *testing.T) { const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize optsOff += header.EncodeMSSOption(mss, opts[optsOff:]) - conn.ConnectWithOptions(opts[:optsOff]) + conn.ConnectWithOptions(t, opts[:optsOff]) - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - if tb.DUTType == "linux" { + if tb.Native { // Linux has changed its handling of reordering, force the old behavior. - dut.SetSockOpt(acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) + dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) } - pls := dut.GetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) - if tb.DUTType == "netstack" { + pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) + if !tb.Native { // netstack does not impliment TCP_MAXSEG correctly. Fake it // here. Netstack uses the max SACK size which is 32. The MSS // option is 8 bytes, making the total 36 bytes. @@ -69,13 +69,13 @@ func TestReorderingWindow(t *testing.T) { payload := make([]byte, pls) - seqNum1 := *conn.RemoteSeqNum() + seqNum1 := *conn.RemoteSeqNum(t) const numPkts = 10 // Send some packets, checking that we receive each. for i, sn := 0, seqNum1; i < numPkts; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -86,7 +86,7 @@ func TestReorderingWindow(t *testing.T) { } } - seqNum2 := *conn.RemoteSeqNum() + seqNum2 := *conn.RemoteSeqNum(t) // SACK packets #2-4. sackBlock := make([]byte, 40) @@ -97,13 +97,13 @@ func TestReorderingWindow(t *testing.T) { seqNum1.Add(seqnum.Size(len(payload))), seqNum1.Add(seqnum.Size(4 * len(payload))), }}, sackBlock[sbOff:]) - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) // ACK first packet. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) // Check for retransmit. - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) if err != nil { t.Error("Expect for retransmit:", err) } @@ -123,14 +123,14 @@ func TestReorderingWindow(t *testing.T) { seqNum1.Add(seqnum.Size(4 * len(payload))), }}, dsackBlock[dsbOff:]) - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) // Send half of the original window of packets, checking that we // received each. for i, sn := 0, seqNum2; i < numPkts/2; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -141,11 +141,11 @@ func TestReorderingWindow(t *testing.T) { } } - if tb.DUTType == "netstack" { + if !tb.Native { // The window should now be halved, so we should receive any // more, even if we send them. - dut.Send(acceptFd, payload, 0) - if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) } return @@ -153,9 +153,9 @@ func TestReorderingWindow(t *testing.T) { // Linux reduces the window by three. Check that we can receive the rest. for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -167,8 +167,8 @@ func TestReorderingWindow(t *testing.T) { } // The window should now be full. - dut.Send(acceptFd, payload, 0) - if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) } } diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go index 6940eb7fb..072014ff8 100644 --- a/test/packetimpact/tests/tcp_retransmits_test.go +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -33,41 +33,41 @@ func init() { func TestRetransmits(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) startRTO := time.Second current := startRTO first := time.Now() - dut.Send(acceptFd, sampleData, 0) - seq := testbench.Uint32(uint32(*conn.RemoteSeqNum())) - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Expect retransmits of the same segment. for i := 0; i < 5; i++ { start := time.Now() - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { t.Fatalf("expected payload was not received: %s loop %d", err, i) } if i == 0 { diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go index 90ab85419..f91b06ba1 100644 --- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go +++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go @@ -61,23 +61,23 @@ func TestSendWindowSizesPiggyback(t *testing.T) { t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1} - if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } @@ -86,18 +86,18 @@ func TestSendWindowSizesPiggyback(t *testing.T) { if tt.enqueue { // Enqueue a segment for the dut to transmit. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) } // Send ACK for the previous segment along with data for the dut to // receive and ACK back. Sending this ACK would make room for the dut // to transmit any enqueued segment. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) // Expect the dut to piggyback the ACK for received data along with // the segment enqueued for transmit. expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2} - if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } }) diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go index 7d5deab01..57d034dd1 100644 --- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go +++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go @@ -32,21 +32,21 @@ func init() { func TestTCPSynRcvdReset(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Expect dut connection to have transitioned to SYN-RCVD state. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST %s", err) } } diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go index 6898a2239..eac8eb19d 100644 --- a/test/packetimpact/tests/tcp_synsent_reset_test.go +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -31,17 +31,19 @@ func init() { // dutSynSentState sets up the dut connection in SYN-SENT state. func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { + t.Helper() + dut := tb.NewDUT(t) - clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) port := uint16(9001) conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port}) sa := unix.SockaddrInet4{Port: int(port)} copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4()) // Bring the dut to SYN-SENT state with a non-blocking connect. - dut.Connect(clientFD, &sa) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { + dut.Connect(t, clientFD, &sa) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { t.Fatalf("expected SYN\n") } @@ -51,13 +53,13 @@ func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { // TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. func TestTCPSynSentReset(t *testing.T) { dut, conn, _, _ := dutSynSentState(t) - defer conn.Close() + defer conn.Close(t) defer dut.TearDown() - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) // Expect the connection to have closed. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } @@ -67,22 +69,22 @@ func TestTCPSynSentReset(t *testing.T) { func TestTCPSynSentRcvdReset(t *testing.T) { dut, c, remotePort, clientPort := dutSynSentState(t) defer dut.TearDown() - defer c.Close() + defer c.Close(t) conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Initiate new SYN connection with the same port pair // (simultaneous open case), expect the dut connection to move to // SYN-RCVD state - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s\n", err) } - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go new file mode 100644 index 000000000..2f76a6531 --- /dev/null +++ b/test/packetimpact/tests/tcp_timewait_reset_test.go @@ -0,0 +1,68 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_timewait_reset_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTimeWaitReset tests handling of RST when in TIME_WAIT state. +func TestTimeWaitReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Close(t, acceptFD) + + _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send a FIN, DUT should transition to TIME_WAIT from FIN_WAIT2. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our FIN: %s", err) + } + + // Send a RST, the DUT should transition to CLOSED from TIME_WAIT. + // This is the default Linux behavior, it can be changed to ignore RSTs via + // sysctl net.ipv4.tcp_rfc1337. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // The DUT should reply with RST to our ACK as the state should have + // transitioned to CLOSED. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go new file mode 100644 index 000000000..d078bbf15 --- /dev/null +++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go @@ -0,0 +1,234 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_unacc_seq_ack_test + +import ( + "flag" + "fmt" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestEstablishedUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + restoreSeq bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + dut.Accept(t, listenFD) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected ack %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + origSeq := *conn.LocalSeqNum(t) + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) + if tt.restoreSeq { + // Restore the local sequence number to ensure that the incoming + // ACK matches the TCP layer state. + *conn.LocalSeqNum(t) = origSeq + } + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if err == nil && !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + }) + } +} + +func TestPassiveCloseUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: false}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Send a FIN to DUT to intiate the passive close. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Errorf("expected an ack but got none: %s", err) + } + if err == nil && !tt.expectAck && gotAck != nil { + t.Errorf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(t, acceptFD) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, samplePayload) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +func TestActiveCloseUnaccpSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + restoreSeq bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Shutdown(t, acceptFD, syscall.SHUT_WR) + + // Get to FIN_WAIT2 + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + sendUnaccSeqAck := func(state string) { + t.Helper() + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + origSeq := *conn.LocalSeqNum(t) + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)), samplePayload) + if tt.restoreSeq { + // Restore the local sequence number to ensure that the + // incoming ACK matches the TCP layer state. + *conn.LocalSeqNum(t) = origSeq + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ack in %s state, but got none: %s", state, err) + } + } + + sendUnaccSeqAck("FIN_WAIT2") + + // Send a FIN to DUT to get to TIME_WAIT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter TIME_WAIT: %s", err) + } + + sendUnaccSeqAck("TIME_WAIT") + }) + } +} + +// generateOTWSeqSegment generates an segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} +} + +// generateUnaccACKSegment generates an segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum(t) + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go index 87e45d765..551dc78e7 100644 --- a/test/packetimpact/tests/tcp_user_timeout_test.go +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -16,7 +16,6 @@ package tcp_user_timeout_test import ( "flag" - "fmt" "testing" "time" @@ -29,22 +28,20 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { +func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { sampleData := make([]byte, 100) for i := range sampleData { sampleData[i] = uint8(i) } - conn.Drain() - dut.Send(fd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { - return fmt.Errorf("expected data but got none: %w", err) + conn.Drain(t) + dut.Send(t, fd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + t.Fatalf("expected data but got none: %w", err) } - return nil } -func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { - dut.Close(fd) - return nil +func sendFIN(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { + dut.Close(t, fd) } func TestTCPUserTimeout(t *testing.T) { @@ -59,7 +56,7 @@ func TestTCPUserTimeout(t *testing.T) { } { for _, ttf := range []struct { description string - f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error + f func(_ *testing.T, _ *testbench.TCPIPv4, _ *testbench.DUT, fd int32) }{ {"AfterPayload", sendPayload}, {"AfterFIN", sendFIN}, @@ -68,31 +65,29 @@ func TestTCPUserTimeout(t *testing.T) { // Create a socket, listen, TCP handshake, and accept. dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() - acceptFD, _ := dut.Accept(listenFD) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) if tt.userTimeout != 0 { - dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + dut.SetSockOptInt(t, acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) } - if err := ttf.f(&conn, &dut, acceptFD); err != nil { - t.Fatal(err) - } + ttf.f(t, &conn, &dut, acceptFD) time.Sleep(tt.sendDelay) - conn.Drain() - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Drain(t) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // If TCP_USER_TIMEOUT was set and the above delay was longer than the // TCP_USER_TIMEOUT then the DUT should send a RST in response to the // testbench's packet. expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout expectTimeout := 5 * time.Second - got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) + got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) if expectRST && err != nil { t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) } diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index e78d04756..5b001fbec 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -31,43 +31,43 @@ func init() { func TestWindowShrink(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - dut.Send(acceptFd, sampleData, 0) - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // We close our receiving window here - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - dut.Send(acceptFd, []byte("Sample Data"), 0) + dut.Send(t, acceptFd, []byte("Sample Data"), 0) // Note: There is another kind of zero-window probing which Windows uses (by sending one // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change // the following lines. - expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + expectedRemoteSeqNum := *conn.RemoteSeqNum(t) - 1 + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index 8c89d57c9..da93267d6 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -33,27 +33,27 @@ func init() { func TestZeroWindowProbeRetransmit(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -63,15 +63,15 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { // of the recorded first zero probe transmission duration. // // Advertize zero receive window again. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) - ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) startProbeDuration := time.Second current := startProbeDuration first := time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect the dut to keep the connection alive as long as the remote is // acknowledging the zero-window probes. for i := 0; i < 5; i++ { @@ -79,7 +79,7 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { // Expect zero-window probe with a timeout which is a function of the typical // first retransmission time. The retransmission times is supposed to // exponentially increase. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) } if i == 0 { @@ -92,14 +92,13 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) } // Acknowledge the zero-window probes from the dut. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) current *= 2 } // Advertize non-zero window. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. - if _, err := conn.ExpectData(&testbench. - TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go index 649fd5699..44cac42f8 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -33,29 +33,29 @@ func init() { func TestZeroWindowProbe(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} start := time.Now() // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } sendTime := time.Now().Sub(start) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -63,24 +63,24 @@ func TestZeroWindowProbe(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) // Expected ack number of the ACK for the probe. - ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) // Expect there are no zero-window probes sent until there is data to be sent out // from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err) } start = time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect zero-window probe from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) } // Expect the probe to be sent after some time. Compare against the previous @@ -94,9 +94,9 @@ func TestZeroWindowProbe(t *testing.T) { // and sends out the sample payload after the send window opens. // // Advertize non-zero window to the dut and ack the zero window probe. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } @@ -104,9 +104,9 @@ func TestZeroWindowProbe(t *testing.T) { // Check if the dut responds as we do for a similar probe sent to it. // Basically with sequence number to one byte behind the unacknowledged // sequence number. - p := testbench.Uint32(uint32(*conn.LocalSeqNum())) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { + p := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { t.Fatalf("expected a packet with ack number: %d: %s", p, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go index 3c467b14f..09a1c653f 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -33,27 +33,27 @@ func init() { func TestZeroWindowProbeUserTimeout(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -61,15 +61,15 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) start := time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect zero-window probe from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) } // Record the duration for first probe, the dut sends the zero window probe after @@ -80,19 +80,19 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // when the dut is sending zero-window probes. // // Reduce the retransmit timeout. - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) // Advertize zero window again. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Ask the dut to send out data that would trigger zero window probe retransmissions. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Wait for the connection to timeout after multiple zero-window probe retransmissions. time.Sleep(8 * startProbeDuration) // Expect the connection to have timed out and closed which would cause the dut // to reply with a RST to the ACK we send. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go index 77a9bfa1d..17f32ef65 100644 --- a/test/packetimpact/tests/udp_recv_multicast_test.go +++ b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package udp_recv_multicast_test +package udp_any_addr_recv_unicast_test import ( "flag" "net" "testing" + "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/test/packetimpact/testbench" @@ -28,13 +29,23 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func TestUDPRecvMulticast(t *testing.T) { +func TestAnyRecvUnicastUDP(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(boundFD) + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, boundFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() - conn.SendIP(testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))}, testbench.UDP{}) - dut.Recv(boundFD, 100, 0) + defer conn.Close(t) + + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP(testbench.RemoteIPv4).To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } } diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go new file mode 100644 index 000000000..3d2791a6e --- /dev/null +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -0,0 +1,96 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_discard_mcast_source_addr_test + +import ( + "context" + "flag" + "fmt" + "net" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +var oneSecond = unix.Timeval{Sec: 1, Usec: 0} + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, mcastAddr := range []net.IP{ + net.IPv4allsys, + net.IPv4allrouter, + net.IPv4(224, 0, 1, 42), + net.IPv4(232, 1, 2, 3), + } { + t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { + conn.SendIP( + t, + testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: []byte("test payload")}, + ) + + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} + +func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, mcastAddr := range []net.IP{ + net.IPv6interfacelocalallnodes, + net.IPv6linklocalallnodes, + net.IPv6linklocalallrouters, + net.ParseIP("ff01::42"), + net.ParseIP("ff02::4242"), + } { + t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { + conn.SendIPv6( + t, + testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: []byte("test payload")}, + ) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index b754918f6..df35d16c8 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -62,9 +62,13 @@ func (e icmpError) String() string { func (e icmpError) ToICMPv4() *testbench.ICMPv4 { switch e { case portUnreachable: - return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4PortUnreachable)} + return &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4PortUnreachable)} case timeToLiveExceeded: - return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), Code: testbench.Uint8(header.ICMPv4TTLExceeded)} + return &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), + Code: testbench.ICMPv4Code(header.ICMPv4TTLExceeded)} } return nil } @@ -72,7 +76,7 @@ func (e icmpError) ToICMPv4() *testbench.ICMPv4 { type errorDetection struct { name string useValidConn bool - f func(context.Context, testData) error + f func(context.Context, *testing.T, testData) } type testData struct { @@ -95,12 +99,14 @@ func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { } // sendICMPError sends an ICMP error message in response to a UDP datagram. -func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error { - layers := (*testbench.Connection)(conn).CreateFrame(nil) +func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) { + t.Helper() + + layers := (*testbench.Connection)(conn).CreateFrame(t, nil) layers = layers[:len(layers)-1] ip, ok := udp.Prev().(*testbench.IPv4) if !ok { - return fmt.Errorf("expected %s to be IPv4", udp.Prev()) + t.Fatalf("expected %s to be IPv4", udp.Prev()) } if icmpErr == timeToLiveExceeded { *ip.TTL = 1 @@ -114,84 +120,82 @@ func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UD // resulting in a mal-formed packet. layers = append(layers, icmpErr.ToICMPv4(), ip, udp) - (*testbench.Connection)(conn).SendFrameStateless(layers) - return nil + (*testbench.Connection)(conn).SendFrameStateless(t, layers) } // testRecv tests observing the ICMP error through the recv syscall. A packet // is sent to the DUT, and if wantErrno is non-zero, then the first recv should // fail and the second should succeed. Otherwise if wantErrno is zero then the // first recv should succeed immediately. -func testRecv(ctx context.Context, d testData) error { +func testRecv(ctx context.Context, t *testing.T, d testData) { + t.Helper() + // Check that receiving on the clean socket works. - d.conn.Send(testbench.UDP{DstPort: &d.cleanPort}) - d.dut.Recv(d.cleanFD, 100, 0) + d.conn.Send(t, testbench.UDP{DstPort: &d.cleanPort}) + d.dut.Recv(t, d.cleanFD, 100, 0) - d.conn.Send(testbench.UDP{}) + d.conn.Send(t, testbench.UDP{}) if d.wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - ret, _, err := d.dut.RecvWithErrno(ctx, d.remoteFD, 100, 0) + ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0) if ret != -1 { - return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) } if err != d.wantErrno { - return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + t.Fatalf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) } } - d.dut.Recv(d.remoteFD, 100, 0) - return nil + d.dut.Recv(t, d.remoteFD, 100, 0) } // testSendTo tests observing the ICMP error through the send syscall. If // wantErrno is non-zero, the first send should fail and a subsequent send // should suceed; while if wantErrno is zero then the first send should just // succeed. -func testSendTo(ctx context.Context, d testData) error { +func testSendTo(ctx context.Context, t *testing.T, d testData) { // Check that sending on the clean socket works. - d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err) + d.dut.SendTo(t, d.cleanFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err) } if d.wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - ret, err := d.dut.SendToWithErrno(ctx, d.remoteFD, nil, 0, d.conn.LocalAddr()) + ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) if ret != -1 { - return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + t.Fatalf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) } if err != d.wantErrno { - return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + t.Fatalf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) } } - d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet as expected: %s", err) + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) } - return nil } -func testSockOpt(_ context.Context, d testData) error { +func testSockOpt(_ context.Context, t *testing.T, d testData) { // Check that there's no pending error on the clean socket. - if errno := syscall.Errno(d.dut.GetSockOptInt(d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { - return fmt.Errorf("unexpected error (%[1]d) %[1]v on clean socket", errno) + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { + t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno) } - if errno := syscall.Errno(d.dut.GetSockOptInt(d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { - return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { + t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) } // Check that after clearing socket error, sending doesn't fail. - d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet as expected: %s", err) + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) } - return nil } // TestUDPICMPErrorPropagation tests that ICMP error messages in response to @@ -227,31 +231,29 @@ func TestUDPICMPErrorPropagation(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(remoteFD) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, remoteFD) // Create a second, clean socket on the DUT to ensure that the ICMP // error messages only affect the sockets they are intended for. - cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(cleanFD) + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, cleanFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) if connect { - dut.Connect(remoteFD, conn.LocalAddr()) - dut.Connect(cleanFD, conn.LocalAddr()) + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) } - dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) - udp, err := conn.Expect(testbench.UDP{}, time.Second) + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) if err != nil { t.Fatalf("did not receive message from DUT: %s", err) } - if err := sendICMPError(&conn, icmpErr, udp); err != nil { - t.Fatal(err) - } + sendICMPError(t, &conn, icmpErr, udp) errDetectConn := &conn if errDetect.useValidConn { @@ -260,14 +262,12 @@ func TestUDPICMPErrorPropagation(t *testing.T) { // interactions between it and the the DUT should be independent of // the ICMP error at least at the port level. connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer connClean.Close() + defer connClean.Close(t) errDetectConn = &connClean } - if err := errDetect.f(context.Background(), testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}); err != nil { - t.Fatal(err) - } + errDetect.f(context.Background(), t, testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}) }) } } @@ -285,24 +285,24 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(remoteFD) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, remoteFD) // Create a second, clean socket on the DUT to ensure that the ICMP // error messages only affect the sockets they are intended for. - cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(cleanFD) + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, cleanFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) if connect { - dut.Connect(remoteFD, conn.LocalAddr()) - dut.Connect(cleanFD, conn.LocalAddr()) + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) } - dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) - udp, err := conn.Expect(testbench.UDP{}, time.Second) + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) if err != nil { t.Fatalf("did not receive message from DUT: %s", err) } @@ -316,7 +316,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0) + ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0) if ret != -1 { t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) return @@ -330,7 +330,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 { t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) } }() @@ -341,7 +341,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 { t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) } }() @@ -352,12 +352,10 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { // alternative is available. time.Sleep(2 * time.Second) - if err := sendICMPError(&conn, icmpErr, udp); err != nil { - t.Fatal(err) - } + sendICMPError(t, &conn, icmpErr, udp) - conn.Send(testbench.UDP{DstPort: &cleanPort}) - conn.Send(testbench.UDP{}) + conn.Send(t, testbench.UDP{DstPort: &cleanPort}) + conn.Send(t, testbench.UDP{}) wg.Wait() }) } diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go new file mode 100644 index 000000000..526173969 --- /dev/null +++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_mcast_bcast_test + +import ( + "context" + "flag" + "fmt" + "net" + "syscall" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestUDPRecvMcastBcast(t *testing.T) { + subnetBcastAddr := broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)) + + for _, v := range []struct { + bound, to net.IP + }{ + {bound: net.IPv4zero, to: subnetBcastAddr}, + {bound: net.IPv4zero, to: net.IPv4bcast}, + {bound: net.IPv4zero, to: net.IPv4allsys}, + + {bound: subnetBcastAddr, to: subnetBcastAddr}, + {bound: subnetBcastAddr, to: net.IPv4bcast}, + + {bound: net.IPv4bcast, to: net.IPv4bcast}, + {bound: net.IPv4allsys, to: net.IPv4allsys}, + } { + t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + }) + } +} + +func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0}) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, to := range []net.IP{ + broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)), + net.IPv4(255, 255, 255, 255), + net.IPv4(224, 0, 0, 1), + } { + t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) { + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} + +func broadcastAddr(ip net.IP, mask net.IPMask) net.IP { + ip4 := ip.To4() + for i := range ip4 { + ip4[i] |= ^mask[i] + } + return ip4 +} diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 224feef85..91b967400 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -28,62 +29,75 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func TestUDPRecv(t *testing.T) { +type udpConn interface { + Send(*testing.T, testbench.UDP, ...testbench.Layer) + ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) + Drain(*testing.T) + Close(*testing.T) +} + +func TestUDP(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(boundFD) - conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() - testCases := []struct { - name string - payload []byte - }{ - {"emptypayload", nil}, - {"small payload", []byte("hello world")}, - {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, - // Even though UDP allows larger dgrams we don't test it here as - // they need to be fragmented and written out as individual - // frames. - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) - if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want { - t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want) + for _, isIPv4 := range []bool{true, false} { + ipVersionName := "IPv6" + if isIPv4 { + ipVersionName = "IPv4" + } + t.Run(ipVersionName, func(t *testing.T) { + var addr string + if isIPv4 { + addr = testbench.RemoteIPv4 + } else { + addr = testbench.RemoteIPv6 } - }) - } -} + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr)) + defer dut.Close(t, boundFD) -func TestUDPSend(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(boundFD) - conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + var conn udpConn + var localAddr unix.Sockaddr + if isIPv4 { + v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + localAddr = v4Conn.LocalAddr(t) + conn = &v4Conn + } else { + v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + localAddr = v6Conn.LocalAddr(t) + conn = &v6Conn + } + defer conn.Close(t) - testCases := []struct { - name string - payload []byte - }{ - {"emptypayload", nil}, - {"small payload", []byte("hello world")}, - {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, - // Even though UDP allows larger dgrams we don't test it here as - // they need to be fragmented and written out as individual - // frames. - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - conn.Drain() - if got, want := int(dut.SendTo(boundFD, tc.payload, 0, conn.LocalAddr())), len(tc.payload); got != want { - t.Fatalf("short write got: %d, want: %d", got, want) + testCases := []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. } - if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, 1*time.Second); err != nil { - t.Fatal(err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Run("Send", func(t *testing.T) { + conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) + got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + }) + t.Run("Recv", func(t *testing.T) { + conn.Drain(t) + if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { + t.Fatalf("short write got: %d, want: %d", got, want) + } + if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { + t.Fatal(err) + } + }) + }) } }) } diff --git a/test/perf/BUILD b/test/perf/BUILD index 471d8c2ab..b763be50e 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -3,33 +3,40 @@ load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) syscall_test( + debug = False, test = "//test/perf/linux:clock_getres_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:clock_gettime_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:death_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:epoll_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:fork_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:futex_benchmark", ) syscall_test( size = "enormous", + debug = False, shard_count = 10, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", @@ -37,81 +44,96 @@ syscall_test( syscall_test( size = "large", + debug = False, test = "//test/perf/linux:getpid_benchmark", ) syscall_test( size = "enormous", + debug = False, tags = ["nogotsan"], test = "//test/perf/linux:gettid_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:mapping_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:open_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:pipe_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:randread_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:read_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:sched_yield_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:send_recv_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:seqwrite_benchmark", ) syscall_test( size = "enormous", + debug = False, test = "//test/perf/linux:signal_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:sleep_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:stat_benchmark", ) syscall_test( size = "enormous", add_overlay = True, + debug = False, test = "//test/perf/linux:unlink_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:write_benchmark", ) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD index b4e907826..dd1d2438c 100644 --- a/test/perf/linux/BUILD +++ b/test/perf/linux/BUILD @@ -354,3 +354,19 @@ cc_binary( "//test/util:test_util", ], ) + +cc_binary( + name = "open_read_close_benchmark", + testonly = 1, + srcs = [ + "open_read_close_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc index d8e81fa8c..9030eb356 100644 --- a/test/perf/linux/getdents_benchmark.cc +++ b/test/perf/linux/getdents_benchmark.cc @@ -105,7 +105,7 @@ void BM_GetdentsSameFD(benchmark::State& state) { state.SetItemsProcessed(state.iterations()); } -BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 12)->UseRealTime(); // Creates a directory containing `files` files, and reads all the directory // entries from the directory using a new FD each time. diff --git a/test/perf/linux/open_read_close_benchmark.cc b/test/perf/linux/open_read_close_benchmark.cc new file mode 100644 index 000000000..8b023a3d8 --- /dev/null +++ b/test/perf/linux/open_read_close_benchmark.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_OpenReadClose(benchmark::State& state) { + const int size = state.range(0); + std::vector<TempPath> cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "some content", 0644)); + cache.emplace_back(std::move(path)); + } + + char buf[1]; + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + TEST_CHECK(read(fd, buf, 1) == 1); + close(fd); + } +} + +// Gofer dentry cache is 1000 by default. Go over it to force files to be closed +// for real. +BENCHMARK(BM_OpenReadClose)->Range(1000, 16384)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/BUILD b/test/root/BUILD index a9e91ccd6..a9130b34f 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -41,7 +41,7 @@ go_test( "//runsc/container", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], @@ -51,8 +51,5 @@ vm_test( name = "root_vm_test", size = "large", shard_count = 1, - targets = [ - "//tools/installers:shim", - ":root_test", - ], + targets = [":root_test"], ) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index d0634b5c3..a26b83081 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -16,6 +16,7 @@ package root import ( "bufio" + "context" "fmt" "io/ioutil" "os" @@ -56,25 +57,24 @@ func verifyPid(pid int, path string) error { return fmt.Errorf("got: %v, want: %d", gots, pid) } -func TestMemCGroup(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() +func TestMemCgroup(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start a new container and allocate the specified about of memory. allocMemSize := 128 << 20 allocMemLimit := 2 * allocMemSize - if err := d.Spawn(dockerutil.RunOpts{ - Image: "basic/python", - Memory: allocMemLimit / 1024, // Must be in Kb. - }, "python", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil { + + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + Memory: allocMemLimit, // Must be in bytes. + }, "python3", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil { t.Fatalf("docker run failed: %v", err) } // Extract the ID to lookup the cgroup. - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) - } + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Wait when the container will allocate memory. @@ -127,8 +127,9 @@ func TestMemCGroup(t *testing.T) { // TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroup(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // This is not a comprehensive list of attributes. // @@ -137,94 +138,133 @@ func TestCgroup(t *testing.T) { // are often run on a single core virtual machine, and there is only a single // CPU available in our current set, and every container's set. attrs := []struct { - arg string + field string + value int64 ctrl string file string want string skipIfNotFound bool }{ { - arg: "--cpu-shares=1000", - ctrl: "cpu", - file: "cpu.shares", - want: "1000", + field: "cpu-shares", + value: 1000, + ctrl: "cpu", + file: "cpu.shares", + want: "1000", }, { - arg: "--cpu-period=2000", - ctrl: "cpu", - file: "cpu.cfs_period_us", - want: "2000", + field: "cpu-period", + value: 2000, + ctrl: "cpu", + file: "cpu.cfs_period_us", + want: "2000", }, { - arg: "--cpu-quota=3000", - ctrl: "cpu", - file: "cpu.cfs_quota_us", - want: "3000", + field: "cpu-quota", + value: 3000, + ctrl: "cpu", + file: "cpu.cfs_quota_us", + want: "3000", }, { - arg: "--kernel-memory=100MB", - ctrl: "memory", - file: "memory.kmem.limit_in_bytes", - want: "104857600", + field: "kernel-memory", + value: 100 << 20, + ctrl: "memory", + file: "memory.kmem.limit_in_bytes", + want: "104857600", }, { - arg: "--memory=1GB", - ctrl: "memory", - file: "memory.limit_in_bytes", - want: "1073741824", + field: "memory", + value: 1 << 30, + ctrl: "memory", + file: "memory.limit_in_bytes", + want: "1073741824", }, { - arg: "--memory-reservation=500MB", - ctrl: "memory", - file: "memory.soft_limit_in_bytes", - want: "524288000", + field: "memory-reservation", + value: 500 << 20, + ctrl: "memory", + file: "memory.soft_limit_in_bytes", + want: "524288000", }, { - arg: "--memory-swap=2GB", + field: "memory-swap", + value: 2 << 30, ctrl: "memory", file: "memory.memsw.limit_in_bytes", want: "2147483648", skipIfNotFound: true, // swap may be disabled on the machine. }, { - arg: "--memory-swappiness=5", - ctrl: "memory", - file: "memory.swappiness", - want: "5", + field: "memory-swappiness", + value: 5, + ctrl: "memory", + file: "memory.swappiness", + want: "5", }, { - arg: "--blkio-weight=750", + field: "blkio-weight", + value: 750, ctrl: "blkio", file: "blkio.weight", want: "750", skipIfNotFound: true, // blkio groups may not be available. }, { - arg: "--pids-limit=1000", - ctrl: "pids", - file: "pids.max", - want: "1000", + field: "pids-limit", + value: 1000, + ctrl: "pids", + file: "pids.max", + want: "1000", }, } - args := make([]string, 0, len(attrs)) + // Make configs. + conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000") + + // Add Cgroup arguments to configs. for _, attr := range attrs { - args = append(args, attr.arg) + switch attr.field { + case "cpu-shares": + hostconf.Resources.CPUShares = attr.value + case "cpu-period": + hostconf.Resources.CPUPeriod = attr.value + case "cpu-quota": + hostconf.Resources.CPUQuota = attr.value + case "kernel-memory": + hostconf.Resources.KernelMemory = attr.value + case "memory": + hostconf.Resources.Memory = attr.value + case "memory-reservation": + hostconf.Resources.MemoryReservation = attr.value + case "memory-swap": + hostconf.Resources.MemorySwap = attr.value + case "memory-swappiness": + val := attr.value + hostconf.Resources.MemorySwappiness = &val + case "blkio-weight": + hostconf.Resources.BlkioWeight = uint16(attr.value) + case "pids-limit": + val := attr.value + hostconf.Resources.PidsLimit = &val + + } } - // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ - Image: "basic/alpine", - Extra: args, // Cgroup arguments. - }, "sleep", "10000"); err != nil { - t.Fatalf("docker run failed: %v", err) + // Create container. + if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("create failed with: %v", err) } - // Lookup the relevant cgroup ID. - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) + // Start container. + if err := d.Start(ctx); err != nil { + t.Fatalf("start failed with: %v", err) } + + // Lookup the relevant cgroup ID. + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Check list of attributes defined above. @@ -239,7 +279,7 @@ func TestCgroup(t *testing.T) { t.Fatalf("failed to read %q: %v", path, err) } if got := strings.TrimSpace(string(out)); got != attr.want { - t.Errorf("arg: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.arg, attr.ctrl, attr.file, got, attr.want) + t.Errorf("field: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.field, attr.ctrl, attr.file, got, attr.want) } } @@ -257,7 +297,7 @@ func TestCgroup(t *testing.T) { "pids", "systemd", } - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("SandboxPid: %v", err) } @@ -269,29 +309,34 @@ func TestCgroup(t *testing.T) { } } -// TestCgroup sets cgroup options and checks that cgroup was properly configured. +// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's +// cgroups are created correctly relative to each other. func TestCgroupParent(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Construct a known cgroup name. parent := testutil.RandomID("runsc-") - if err := d.Spawn(dockerutil.RunOpts{ + conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{ Image: "basic/alpine", - Extra: []string{fmt.Sprintf("--cgroup-parent=%s", parent)}, - }, "sleep", "10000"); err != nil { - t.Fatalf("docker run failed: %v", err) + }, "sleep", "10000") + hostconf.Resources.CgroupParent = parent + + if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("create failed with: %v", err) } - // Extract the ID to look up the cgroup. - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) + if err := d.Start(ctx); err != nil { + t.Fatalf("start failed with: %v", err) } + + // Extract the ID to look up the cgroup. + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Check that sandbox is inside cgroup. - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("SandboxPid: %v", err) } diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go index a306132a4..58fcd6f08 100644 --- a/test/root/chroot_test.go +++ b/test/root/chroot_test.go @@ -16,6 +16,7 @@ package root import ( + "context" "fmt" "io/ioutil" "os/exec" @@ -30,16 +31,17 @@ import ( // TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned // up after the sandbox is destroyed. func TestChroot(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) } - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("Docker.SandboxPid(): %v", err) } @@ -75,14 +77,15 @@ func TestChroot(t *testing.T) { t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc") } - d.CleanUp() + d.CleanUp(ctx) } func TestChrootGofer(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) @@ -91,7 +94,7 @@ func TestChrootGofer(t *testing.T) { // It's tricky to find gofers. Get sandbox PID first, then find parent. From // parent get all immediate children, remove the sandbox, and everything else // are gofers. - sandPID, err := d.SandboxPid() + sandPID, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("Docker.SandboxPid(): %v", err) } diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index 732fae821..11ac5cb52 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -20,13 +20,14 @@ import ( "fmt" "io" "io/ioutil" - "log" "net/http" "os" "os/exec" "path" - "path/filepath" + "regexp" + "strconv" "strings" + "sync" "testing" "time" @@ -75,6 +76,8 @@ func SimpleSpec(name, image string, cmd []string, extra map[string]interface{}) // Log files are not deleted after root tests are run. Log to random // paths to ensure logs are fresh. "log_path": fmt.Sprintf("%s.log", testutil.RandomID(name)), + "stdin": false, + "tty": false, } if len(cmd) > 0 { // Omit if empty. s["command"] = cmd @@ -95,25 +98,29 @@ var Httpd = SimpleSpec("httpd", "basic/httpd", nil, nil) // TestCrictlSanity refers to b/112433158. func TestCrictlSanity(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("default"), Httpd) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - // Look for the httpd page. - if err = httpGet(crictl, podID, "index.html"); err != nil { - t.Fatalf("failed to get page: %v", err) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), Httpd) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + // Look for the httpd page. + if err = httpGet(crictl, podID, "index.html"); err != nil { + t.Fatalf("failed to get page: %v", err) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } @@ -147,146 +154,179 @@ var HttpdMountPaths = SimpleSpec("httpd", "basic/httpd", nil, map[string]interfa // TestMountPaths refers to b/117635704. func TestMountPaths(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("default"), HttpdMountPaths) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - // Look for the directory available at /test. - if err = httpGet(crictl, podID, "test"); err != nil { - t.Fatalf("failed to get page: %v", err) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), HttpdMountPaths) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + // Look for the directory available at /test. + if err = httpGet(crictl, podID, "test"); err != nil { + t.Fatalf("failed to get page: %v", err) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } // TestMountPaths refers to b/118728671. func TestMountOverSymlinks(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - - spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil) - podID, contID, err := crictl.StartPodAndContainer("basic/resolv", Sandbox("default"), spec) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") - if err != nil { - t.Fatalf("readlink failed: %v, out: %s", err, out) - } - if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { - t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) - } - - etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") - if err != nil { - t.Fatalf("cat failed: %v, out: %s", err, etc) - } - tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") - if err != nil { - t.Fatalf("cat failed: %v, out: %s", err, out) - } - if tmp != etc { - t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + + spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/resolv", Sandbox("default"), spec) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") + if err != nil { + t.Fatalf("readlink failed: %v, out: %s", err, out) + } + if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { + t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) + } + + etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") + if err != nil { + t.Fatalf("cat failed: %v, out: %s", err, etc) + } + tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") + if err != nil { + t.Fatalf("cat failed: %v, out: %s", err, out) + } + if tmp != etc { + t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } // TestHomeDir tests that the HOME environment variable is set for // Pod containers. func TestHomeDir(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + + // Note that container ID returned here is a sub-container. All Pod + // containers are sub-containers. The root container of the sandbox is the + // pause container. + t.Run("sub-container", func(t *testing.T) { + contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("subcont-sandbox"), contSpec) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + out, err := crictl.Logs(contID) + if err != nil { + t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } + + // Stop everything; note that the pod may have already stopped. + crictl.StopPodAndContainer(podID, contID) + }) + + // Tests that HOME is set for the exec process. + t.Run("exec", func(t *testing.T) { + contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("exec-sandbox"), contSpec) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) + }) } - defer cleanup() - - // Note that container ID returned here is a sub-container. All Pod - // containers are sub-containers. The root container of the sandbox is the - // pause container. - t.Run("sub-container", func(t *testing.T) { - contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil) - podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox("subcont-sandbox"), contSpec) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - out, err := crictl.Logs(contID) - if err != nil { - t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) - } - }) - - // Tests that HOME is set for the exec process. - t.Run("exec", func(t *testing.T) { - contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil) - podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox("exec-sandbox"), contSpec) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") - if err != nil { - t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) - } - }) } -// containerdConfigTemplate is a .toml config for containerd. It contains a -// formatting verb so the runtime field can be set via fmt.Sprintf. -const containerdConfigTemplate = ` +const containerdRuntime = "runsc" + +const v1Template = ` disabled_plugins = ["restart"] +[plugins.cri] + disable_tcp_service = true [plugins.linux] - runtime = "%s" - runtime_root = "/tmp/test-containerd/runsc" - shim = "/usr/local/bin/gvisor-containerd-shim" + shim = "%s" shim_debug = true - -[plugins.cri.containerd.runtimes.runsc] +[plugins.cri.containerd.runtimes.` + containerdRuntime + `] runtime_type = "io.containerd.runtime.v1.linux" runtime_engine = "%s" + runtime_root = "%s/root/runsc" ` +const v2Template = ` +disabled_plugins = ["restart"] +[plugins.cri] + disable_tcp_service = true +[plugins.linux] + shim_debug = true +[plugins.cri.containerd.runtimes.` + containerdRuntime + `] + runtime_type = "io.containerd.` + containerdRuntime + `.v1" +[plugins.cri.containerd.runtimes.` + containerdRuntime + `.options] + TypeUrl = "io.containerd.` + containerdRuntime + `.v1.options" +` + +const ( + // v1 is the containerd API v1. + v1 string = "v1" + + // v1 is the containerd API v21. + v2 string = "v2" +) + +// allVersions is the set of known versions. +var allVersions = []string{v1, v2} + // setup sets up before a test. Specifically it: // * Creates directories and a socket for containerd to utilize. // * Runs containerd and waits for it to reach a "ready" state for testing. // * Returns a cleanup function that should be called at the end of the test. -func setup(t *testing.T) (*criutil.Crictl, func(), error) { +func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) { // Create temporary containerd root and state directories, and a socket // via which crictl and containerd communicate. containerdRoot, err := ioutil.TempDir(testutil.TmpDir(), "containerd-root") @@ -295,13 +335,43 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { } cu := cleanup.Make(func() { os.RemoveAll(containerdRoot) }) defer cu.Clean() + t.Logf("Using containerd root: %s", containerdRoot) containerdState, err := ioutil.TempDir(testutil.TmpDir(), "containerd-state") if err != nil { t.Fatalf("failed to create containerd state: %v", err) } cu.Add(func() { os.RemoveAll(containerdState) }) - sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock") + t.Logf("Using containerd state: %s", containerdState) + + sockDir, err := ioutil.TempDir(testutil.TmpDir(), "containerd-sock") + if err != nil { + t.Fatalf("failed to create containerd socket directory: %v", err) + } + cu.Add(func() { os.RemoveAll(sockDir) }) + sockAddr := path.Join(sockDir, "test.sock") + t.Logf("Using containerd socket: %s", sockAddr) + + // Extract the containerd version. + versionCmd := exec.Command(getContainerd(), "-v") + out, err := versionCmd.CombinedOutput() + if err != nil { + t.Fatalf("error extracting containerd version: %v (%s)", err, string(out)) + } + r := regexp.MustCompile(" v([0-9]+)\\.([0-9]+)\\.([0-9+])") + vs := r.FindStringSubmatch(string(out)) + if len(vs) != 4 { + t.Fatalf("error unexpected version string: %s", string(out)) + } + major, err := strconv.ParseUint(vs[1], 10, 64) + if err != nil { + t.Fatalf("error parsing containerd major version: %v (%s)", err, string(out)) + } + minor, err := strconv.ParseUint(vs[2], 10, 64) + if err != nil { + t.Fatalf("error parsing containerd minor version: %v (%s)", err, string(out)) + } + t.Logf("Using containerd version: %d.%d", major, minor) // We rewrite a configuration. This is based on the current docker // configuration for the runtime under test. @@ -309,28 +379,97 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { if err != nil { t.Fatalf("error discovering runtime path: %v", err) } - config, configCleanup, err := testutil.WriteTmpFile("containerd-config", fmt.Sprintf(containerdConfigTemplate, runtime, runtime)) + t.Logf("Using runtime: %v", runtime) + + // Construct a PATH that includes the runtime directory. This is + // because the shims will be installed there, and containerd may infer + // the binary name and search the PATH. + runtimeDir := path.Dir(runtime) + modifiedPath := os.Getenv("PATH") + if modifiedPath != "" { + modifiedPath = ":" + modifiedPath // We prepend below. + } + modifiedPath = path.Dir(getContainerd()) + modifiedPath + modifiedPath = runtimeDir + ":" + modifiedPath + t.Logf("Using PATH: %v", modifiedPath) + + var ( + config string + runpArgs []string + ) + switch version { + case v1: + // This is only supported less than 1.3. + if major > 1 || (major == 1 && minor >= 3) { + t.Skipf("skipping unsupported containerd (want less than 1.3, got %d.%d)", major, minor) + } + + // We provide the shim, followed by the runtime, and then a + // temporary root directory. + config = fmt.Sprintf(v1Template, criutil.ResolvePath("gvisor-containerd-shim"), runtime, containerdRoot) + case v2: + // This is only supported past 1.2. + if major < 1 || (major == 1 && minor <= 1) { + t.Skipf("skipping incompatible containerd (want at least 1.2, got %d.%d)", major, minor) + } + + // The runtime is provided via parameter. Note that the v2 shim + // binary name is always containerd-shim-* so we don't actually + // care about the docker runtime name. + config = v2Template + default: + t.Fatalf("unknown version: %s", version) + } + t.Logf("Using config: %s", config) + + // Generate the configuration for the test. + configFile, configCleanup, err := testutil.WriteTmpFile("containerd-config", config) if err != nil { t.Fatalf("failed to write containerd config") } cu.Add(configCleanup) // Start containerd. - cmd := exec.Command(getContainerd(), - "--config", config, + args := []string{ + getContainerd(), + "--config", configFile, "--log-level", "debug", "--root", containerdRoot, "--state", containerdState, - "--address", sockAddr) + "--address", sockAddr, + } + t.Logf("Using args: %s", strings.Join(args, " ")) + cmd := exec.Command(args[0], args[1:]...) + cmd.Env = append(os.Environ(), "PATH="+modifiedPath) + + // Include output in logs. + stderrPipe, err := cmd.StderrPipe() + if err != nil { + t.Fatalf("failed to create stderr pipe: %v", err) + } + cu.Add(func() { stderrPipe.Close() }) + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + t.Fatalf("failed to create stdout pipe: %v", err) + } + cu.Add(func() { stdoutPipe.Close() }) + var ( + wg sync.WaitGroup + stderr bytes.Buffer + stdout bytes.Buffer + ) startupR, startupW := io.Pipe() - defer startupR.Close() - defer startupW.Close() - stderr := &bytes.Buffer{} - stdout := &bytes.Buffer{} - cmd.Stderr = io.MultiWriter(startupW, stderr) - cmd.Stdout = io.MultiWriter(startupW, stdout) + wg.Add(2) + go func() { + defer wg.Done() + io.Copy(io.MultiWriter(startupW, &stderr), stderrPipe) + }() + go func() { + defer wg.Done() + io.Copy(io.MultiWriter(startupW, &stdout), stdoutPipe) + }() cu.Add(func() { - // Log output in case of failure. + wg.Wait() t.Logf("containerd stdout: %s", stdout.String()) t.Logf("containerd stderr: %s", stderr.String()) }) @@ -345,13 +484,17 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { t.Fatalf("failed to start containerd: %v", err) } + // Discard all subsequent data. + go io.Copy(ioutil.Discard, startupR) + + // Create the crictl interface. + cc := criutil.NewCrictl(t, sockAddr, runpArgs) + cu.Add(cc.CleanUp) + // Kill must be the last cleanup (as it will be executed first). - cc := criutil.NewCrictl(t, sockAddr) cu.Add(func() { - cc.CleanUp() // Remove tmp files, etc. - if err := testutil.KillCommand(cmd); err != nil { - log.Printf("error killing containerd: %v", err) - } + // Best effort: ignore errors. + testutil.KillCommand(cmd) }) return cc, cu.Release(), nil diff --git a/test/root/root.go b/test/root/root.go index 0f1d29faf..441fa5e2e 100644 --- a/test/root/root.go +++ b/test/root/root.go @@ -17,5 +17,5 @@ // docker, containerd, and crictl installed. To run these tests from the // project root directory: // -// ./scripts/root_tests.sh +// make root-tests package root diff --git a/test/runner/BUILD b/test/runner/BUILD index 6833c9986..582d2946d 100644 --- a/test/runner/BUILD +++ b/test/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary") +load("//tools:defs.bzl", "bzl_library", "go_binary") package(licenses = ["notice"]) @@ -16,7 +16,14 @@ go_binary( "//runsc/specutils", "//test/runner/gtest", "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 921e499be..248053dc3 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -61,7 +61,9 @@ def _syscall_test( file_access = "exclusive", overlay = False, add_uds_tree = False, - vfs2 = False): + vfs2 = False, + fuse = False, + debug = True): # Prepend "runsc" to non-native platform names. full_platform = platform if platform == "native" else "runsc_" + platform @@ -73,6 +75,8 @@ def _syscall_test( name += "_overlay" if vfs2: name += "_vfs2" + if fuse: + name += "_fuse" if network != "none": name += "_" + network + "net" @@ -107,6 +111,9 @@ def _syscall_test( "--overlay=" + str(overlay), "--add-uds-tree=" + str(add_uds_tree), "--vfs2=" + str(vfs2), + "--fuse=" + str(fuse), + "--strace=" + str(debug), + "--debug=" + str(debug), ] # Call the rule above. @@ -128,7 +135,9 @@ def syscall_test( add_overlay = False, add_uds_tree = False, add_hostinet = False, - vfs2 = False, + vfs2 = True, + fuse = False, + debug = True, tags = None): """syscall_test is a macro that will create targets for all platforms. @@ -145,31 +154,12 @@ def syscall_test( if not tags: tags = [] - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "native", - use_tmpfs = False, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - for (platform, platform_tags) in platforms.items(): - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = platform, - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = platform_tags + tags, - ) - vfs2_tags = list(tags) if vfs2: # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2 vfs2_tags.append("vfs2") + if fuse: + vfs2_tags.append("fuse") else: # Don't automatically run tests tests not yet passing. @@ -186,9 +176,33 @@ def syscall_test( add_uds_tree = add_uds_tree, tags = platforms[default_platform] + vfs2_tags, vfs2 = True, + fuse = fuse, ) + if fuse: + # Only generate *_vfs2_fuse target if fuse parameter is enabled. + return + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "native", + use_tmpfs = False, + add_uds_tree = add_uds_tree, + tags = list(tags), + ) + + for (platform, platform_tags) in platforms.items(): + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platform_tags + tags, + ) - # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests. if add_overlay: _syscall_test( test = test, @@ -201,6 +215,23 @@ def syscall_test( overlay = True, ) + # TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests. + overlay_vfs2_tags = list(vfs2_tags) + overlay_vfs2_tags.append("manual") + overlay_vfs2_tags.append("noguitar") + overlay_vfs2_tags.append("notap") + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + overlay_vfs2_tags, + overlay = True, + vfs2 = True, + ) + if add_hostinet: _syscall_test( test = test, diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index 869169ad5..e4445e01b 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -146,10 +146,13 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) } - out = []byte(strings.Trim(string(out), "\n")) + benches := strings.Trim(string(out), "\n") + if len(benches) == 0 { + return t, nil + } // Parse benchmark output. - for _, line := range strings.Split(string(out), "\n") { + for _, line := range strings.Split(benches, "\n") { // Strip comments. line = strings.Split(line, "#")[0] @@ -163,6 +166,5 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes benchmark: true, }) } - return t, nil } diff --git a/test/runner/runner.go b/test/runner/runner.go index 5456e46a6..22d535f8d 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -30,6 +30,7 @@ import ( "time" specs "github.com/opencontainers/runtime-spec/specs-go" + "github.com/syndtr/gocapability/capability" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/test/testutil" @@ -47,6 +48,7 @@ var ( fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") vfs2 = flag.Bool("vfs2", false, "enable VFS2") + fuse = flag.Bool("fuse", false, "enable FUSE") parallel = flag.Bool("parallel", false, "run tests in parallel") runscPath = flag.String("runsc", "", "path to runsc binary") @@ -104,6 +106,16 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.SysProcAttr = &syscall.SysProcAttr{} + + if specutils.HasCapabilities(capability.CAP_SYS_ADMIN) { + cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWUTS + } + + if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { + cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWNET + } + if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) @@ -149,6 +161,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { } if *vfs2 { args = append(args, "-vfs2") + if *fuse { + args = append(args, "-fuse") + } } if *debug { args = append(args, "-debug", "-log-packets=true") @@ -160,13 +175,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { args = append(args, "-fsgofer-host-uds") } - undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR") - if ok { - tdir := filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(tdir, 0755); err != nil { + testLogDir := "" + if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + // Create log directory dedicated for this test. + testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(testLogDir, 0755); err != nil { return fmt.Errorf("could not create test dir: %v", err) } - debugLogDir, err := ioutil.TempDir(tdir, "runsc") + debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) } @@ -215,10 +231,10 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { dArgs := append([]string{}, args...) dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) go func(dArgs []string) { - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + debug := exec.Command(*runscPath, dArgs...) + debug.Stdout = os.Stdout + debug.Stderr = os.Stderr + debug.Run() done <- true }(dArgs) @@ -233,17 +249,17 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { dArgs = append(args, "debug", fmt.Sprintf("--signal=%d", syscall.SIGTERM), id) - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + signal := exec.Command(*runscPath, dArgs...) + signal.Stdout = os.Stdout + signal.Stderr = os.Stderr + signal.Run() }() err = cmd.Run() - if err == nil { + if err == nil && len(testLogDir) > 0 { // If the test passed, then we erase the log directory. This speeds up // uploading logs in continuous integration & saves on disk space. - os.RemoveAll(undeclaredOutputsDir) + os.RemoveAll(testLogDir) } return err @@ -358,6 +374,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { vfsVar := "GVISOR_VFS" if *vfs2 { env = append(env, vfsVar+"=VFS2") + fuseVar := "FUSE_ENABLED" + if *fuse { + env = append(env, fuseVar+"=TRUE") + } else { + env = append(env, fuseVar+"=FALSE") + } } else { env = append(env, vfsVar+"=VFS1") } diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 022de5ff7..22b526f59 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,33 +1,46 @@ +load("//tools:defs.bzl", "bzl_library") load("//test/runtimes:defs.bzl", "runtime_test") package(licenses = ["notice"]) runtime_test( name = "go1.12", - exclude_file = "exclude_go1.12.csv", + exclude_file = "exclude/go1.12.csv", lang = "go", + shard_count = 8, ) runtime_test( name = "java11", - exclude_file = "exclude_java11.csv", + batch = 100, + exclude_file = "exclude/java11.csv", lang = "java", + shard_count = 16, ) runtime_test( name = "nodejs12.4.0", - exclude_file = "exclude_nodejs12.4.0.csv", + exclude_file = "exclude/nodejs12.4.0.csv", lang = "nodejs", + shard_count = 8, ) runtime_test( name = "php7.3.6", - exclude_file = "exclude_php7.3.6.csv", + exclude_file = "exclude/php7.3.6.csv", lang = "php", + shard_count = 8, ) runtime_test( name = "python3.7.3", - exclude_file = "exclude_python3.7.3.csv", + exclude_file = "exclude/python3.7.3.csv", lang = "python", + shard_count = 8, +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md new file mode 100644 index 000000000..9dda1a728 --- /dev/null +++ b/test/runtimes/README.md @@ -0,0 +1,62 @@ +# gVisor Runtime Tests + +App Engine uses gvisor to sandbox application containers. The runtime tests aim +to test `runsc` compatibility with these +[standard runtimes](https://cloud.google.com/appengine/docs/standard/runtimes). +The test itself runs the language-defined tests inside the sandboxed standard +runtime container. + +Note: [Ruby runtime](https://cloud.google.com/appengine/docs/standard/ruby) is +currently in beta mode and so we do not run tests for it yet. + +### Testing Locally + +To run runtime tests individually from a given runtime, use the following table. + +Language | Version | Download Image | Run Test(s) +-------- | ------- | ------------------------------------------- | ----------- +Go | 1.12 | `make -C images load-runtimes_go1.12` | If the test name ends with `.go`, it is an on-disk test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 ( cd /usr/local/go/test ; go run run.go -v -- <TEST_NAME>... )` <br> Otherwise it is a tool test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 go tool dist test -v -no-rebuild ^TEST1$\|^TEST2$...` +Java | 11 | `make -C images load-runtimes_java11` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/java11 jtreg -agentvm -dir:/root/test/jdk -noreport -timeoutFactor:20 -verbose:summary <TEST_NAME>...` +NodeJS | 12.4.0 | `make -C images load-runtimes_nodejs12.4.0` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/nodejs12.4.0 python tools/test.py --timeout=180 <TEST_NAME>...` +Php | 7.3.6 | `make -C images load-runtimes_php7.3.6` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/php7.3.6 make test "TESTS=<TEST_NAME>..."` +Python | 3.7.3 | `make -C images load-runtimes_python3.7.3` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/python3.7.3 ./python -m test <TEST_NAME>...` + +To run an entire runtime test locally, use the following table. + +Note: java runtime test take 1+ hours with 16 cores. + +Language | Version | Running the test suite +-------- | ------- | ---------------------------------------- +Go | 1.12 | `make go1.12-runtime-tests{_vfs2}` +Java | 11 | `make java11-runtime-tests{_vfs2}` +NodeJS | 12.4.0 | `make nodejs12.4.0-runtime-tests{_vfs2}` +Php | 7.3.6 | `make php7.3.6-runtime-tests{_vfs2}` +Python | 3.7.3 | `make python3.7.3-runtime-tests{_vfs2}` + +#### Clean Up + +Sometimes when runtime tests fail or when the testing container itself crashes +unexpectedly, the containers are not removed or sometimes do not even exit. This +can cause some docker commands like `docker system prune` to hang forever. + +Here are some helpful commands (should be executed in order): + +```bash +docker ps -a # Lists all docker processes; useful when investigating hanging containers. +docker kill $(docker ps -a -q) # Kills all running containers. +docker rm $(docker ps -a -q) # Removes all exited containers. +docker system prune # Remove unused data. +``` + +### Testing Infrastructure + +There are 3 components to this tests infrastructure: + +- [`runner`](runner) - This is the test entrypoint. This is the binary is + invoked by `bazel test`. The runner spawns the target runtime container + using `runsc` and then copies over the `proctor` binary into the container. +- [`proctor`](proctor) - This binary acts as our agent inside the container + which communicates with the runner and actually executes tests. +- [`exclude`](exclude) - Holds a CSV file for each language runtime containing + the full path of tests that should be excluded from running along with a + reason for exclusion. diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl index dc3667f05..702522d86 100644 --- a/test/runtimes/defs.bzl +++ b/test/runtimes/defs.bzl @@ -9,6 +9,8 @@ def _runtime_test_impl(ctx): ctx.attr.lang, "--image", ctx.attr.image, + "--batch", + str(ctx.attr.batch), ] if ctx.attr.exclude_file: args += [ @@ -20,7 +22,7 @@ def _runtime_test_impl(ctx): runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) runner_content = "\n".join([ "#!/bin/bash", - "%s %s\n" % (ctx.files._runner[0].short_path, " ".join(args)), + "%s %s $@\n" % (ctx.files._runner[0].short_path, " ".join(args)), ]) ctx.actions.write(runner, runner_content, is_executable = True) @@ -47,11 +49,19 @@ _runtime_test = rule( mandatory = False, allow_single_file = True, ), + "batch": attr.int( + default = 50, + mandatory = False, + ), "_runner": attr.label( default = "//test/runtimes/runner:runner", + executable = True, + cfg = "target", ), "_proctor": attr.label( default = "//test/runtimes/proctor:proctor", + executable = True, + cfg = "target", ), }, test = True, @@ -65,6 +75,7 @@ def runtime_test(name, **kwargs): "local", "manual", ], + size = "enormous", **kwargs ) diff --git a/test/runtimes/exclude/go1.12.csv b/test/runtimes/exclude/go1.12.csv new file mode 100644 index 000000000..81e02cf64 --- /dev/null +++ b/test/runtimes/exclude/go1.12.csv @@ -0,0 +1,13 @@ +test name,bug id,comment +cgo_errors,,FLAKY +cgo_test,,FLAKY +go_test:cmd/go,,FLAKY +go_test:net,b/162473575,setsockopt: protocol not available. +go_test:os,b/118780122,we have a pollable filesystem but that's a surprise +go_test:os/signal,b/118780860,/dev/pts not properly supported. Also being tracked in b/29356795. +go_test:runtime,b/118782341,sigtrap not reported or caught or something. Also being tracked in b/33003106. +go_test:syscall,b/118781998,bad bytes -- bad mem addr; FcntlFlock(F_GETLK) not supported. +runtime:cpu124,b/118778254,segmentation fault +test:0_1,,FLAKY +testcarchive,b/118782924,no sigpipe +testshared,,FLAKY diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv new file mode 100644 index 000000000..f779df8d5 --- /dev/null +++ b/test/runtimes/exclude/java11.csv @@ -0,0 +1,210 @@ +test name,bug id,comment +com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker +com/sun/jdi/InvokeHangTest.java,https://bugs.openjdk.java.net/browse/JDK-8218463, +com/sun/jdi/NashornPopFrameTest.java,, +com/sun/jdi/ProcessAttachTest.java,, +com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker +com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, +com/sun/management/ThreadMXBean/ThreadCpuTimeArray.java,,Test assumes high CPU clock precision +com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, +com/sun/tools/attach/AttachSelf.java,, +com/sun/tools/attach/BasicTests.java,, +com/sun/tools/attach/PermissionTest.java,, +com/sun/tools/attach/StartManagementAgent.java,, +com/sun/tools/attach/TempDirTest.java,, +com/sun/tools/attach/modules/Driver.java,, +java/lang/Character/CheckScript.java,,Fails in Docker +java/lang/Character/CheckUnicode.java,,Fails in Docker +java/lang/Class/GetPackageBootLoaderChildLayer.java,, +java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker +java/lang/module/ModuleDescriptorTest.java,, +java/lang/String/nativeEncoding/StringPlatformChars.java,, +java/net/CookieHandler/B6791927.java,,java.lang.RuntimeException: Expiration date shouldn't be 0 +java/net/ipv6tests/TcpTest.java,,java.net.ConnectException: Connection timed out (Connection timed out) +java/net/ipv6tests/UdpTest.java,,Times out +java/net/Inet6Address/B6558853.java,,Times out +java/net/InetAddress/CheckJNI.java,,java.net.ConnectException: Connection timed out (Connection timed out) +java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, +java/net/MulticastSocket/B6425815.java,,java.net.SocketException: Protocol not available (Error getting socket option) +java/net/MulticastSocket/B6427403.java,,java.net.SocketException: Protocol not available +java/net/MulticastSocket/MulticastTTL.java,, +java/net/MulticastSocket/NetworkInterfaceEmptyGetInetAddressesTest.java,,java.net.SocketException: Protocol not available (Error getting socket option) +java/net/MulticastSocket/NoLoopbackPackets.java,,java.net.SocketException: Protocol not available +java/net/MulticastSocket/Promiscuous.java,, +java/net/MulticastSocket/SetLoopbackMode.java,, +java/net/MulticastSocket/SetTTLAndGetTTL.java,, +java/net/MulticastSocket/Test.java,, +java/net/MulticastSocket/TestDefaults.java,, +java/net/MulticastSocket/TimeToLive.java,, +java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, +java/net/Socket/LinkLocal.java,,java.net.SocketTimeoutException: Receive timed out +java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported +java/net/Socket/UrgentDataTest.java,b/111515323, +java/net/SocketOption/OptionsTest.java,,Fails in Docker +java/net/SocketPermission/SocketPermissionTest.java,, +java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker +java/net/httpclient/RequestBuilderTest.java,,Fails in Docker +java/nio/channels/DatagramChannel/BasicMulticastTests.java,, +java/nio/channels/DatagramChannel/SocketOptionTests.java,,java.net.SocketException: Invalid argument +java/nio/channels/DatagramChannel/UseDGWithIPv6.java,, +java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker +java/nio/channels/FileChannel/directio/PwriteDirect.java,,java.io.IOException: Invalid argument +java/nio/channels/Selector/OutOfBand.java,, +java/nio/channels/Selector/SelectWithConsumer.java,,Flaky +java/nio/channels/ServerSocketChannel/SocketOptionTests.java,, +java/nio/channels/SocketChannel/LingerOnClose.java,, +java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, +java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker +java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, +java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, +java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker +java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker +java/util/Calendar/JapaneseEraNameTest.java,, +java/util/Currency/CurrencyTest.java,,Fails in Docker +java/util/Currency/ValidateISO4217.java,,Fails in Docker +java/util/EnumSet/BogusEnumSet.java,,"java.io.InvalidClassException: java.util.EnumSet; local class incompatible: stream classdesc serialVersionUID = -2409567991088730183, local class serialVersionUID = 1009687484059888093" +java/util/Locale/Bug8040211.java,,java.lang.RuntimeException: Failed. +java/util/Locale/LSRDataTest.java,, +java/util/Properties/CompatibilityTest.java,,"java.lang.RuntimeException: jdk.internal.org.xml.sax.SAXParseException; Internal DTD subset is not allowed. The Properties XML document must have the following DOCTYPE declaration: <!DOCTYPE properties SYSTEM ""http://java.sun.com/dtd/properties.dtd"">" +java/util/ResourceBundle/Control/XMLResourceBundleTest.java,,java.util.MissingResourceException: Can't find bundle for base name XmlRB locale +java/util/ResourceBundle/modules/xmlformat/xmlformat.sh,,Timeout reached: 60000. Process is not alive! +java/util/TimeZone/TimeZoneTest.java,,Uncaught exception thrown in test method TestShortZoneIDs +java/util/concurrent/locks/Lock/TimedAcquireLeak.java,, +java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker +java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,, +java/util/logging/TestLoggerWeakRefLeak.java,, +java/util/spi/ResourceBundleControlProvider/UserDefaultControlTest.java,,java.util.MissingResourceException: Can't find bundle for base name com.foo.XmlRB locale +javax/imageio/AppletResourceTest.java,, +javax/imageio/plugins/jpeg/JPEGsNotAcceleratedTest.java,,java.awt.HeadlessException: No X11 DISPLAY variable was set but this program performed an operation which requires it. +javax/management/security/HashedPasswordFileTest.java,, +javax/net/ssl/DTLS/DTLSBufferOverflowUnderflowTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSHandshakeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSHandshakeWithReplicatedPacketsTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSIncorrectAppDataTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSMFLNTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/DTLS/DTLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSSequenceNumberTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10BufferOverflowUnderflowTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10DataExchangeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10EnginesClosureTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10HandshakeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10HandshakeWithReplicatedPacketsTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10IncorrectAppDataTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10MFLNTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10NotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10RehandshakeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10RehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10RehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10SequenceNumberTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10UnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker +javax/net/ssl/TLS/TLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/TLS/TLSHandshakeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSMFLNTest.java,,Compilation failed +javax/net/ssl/TLS/TLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/TLS/TLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/TLS/TLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSHandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSMFLNTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/TLSv1/TLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSHandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSMFLNTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/TLSv11/TLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/TLSv12/TLSEnginesClosureTest.java,,Compilation failed +javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,, +jdk/jfr/cmd/TestHelp.java,,java.lang.RuntimeException: 'Available commands are:' missing from stdout/stderr +jdk/jfr/cmd/TestPrint.java,,Missing file' missing from stdout/stderr +jdk/jfr/cmd/TestPrintDefault.java,,java.lang.RuntimeException: 'JVMInformation' missing from stdout/stderr +jdk/jfr/cmd/TestPrintJSON.java,,javax.script.ScriptException: <eval>:1:17 Expected an operand but found eof var jsonObject = ^ in <eval> at line number 1 at column number 17 +jdk/jfr/cmd/TestPrintXML.java,,org.xml.sax.SAXParseException; lineNumber: 1; columnNumber: 1; Premature end of file. +jdk/jfr/cmd/TestReconstruct.java,,java.lang.RuntimeException: 'Too few arguments' missing from stdout/stderr +jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr +jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr +jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event +jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default' +jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold +jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base +jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, +jdk/jfr/event/runtime/TestThreadParkEvent.java,, +jdk/jfr/event/sampling/TestNative.java,, +jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,, +jdk/jfr/jcmd/TestJcmdConfigure.java,, +jdk/jfr/jcmd/TestJcmdDump.java,, +jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,, +jdk/jfr/jcmd/TestJcmdDumpLimited.java,, +jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,, +jdk/jfr/jcmd/TestJcmdLegacy.java,, +jdk/jfr/jcmd/TestJcmdSaveToFile.java,, +jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,, +jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,, +jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,, +jdk/jfr/jcmd/TestJcmdStartStopDefault.java,, +jdk/jfr/jcmd/TestJcmdStartWithOptions.java,, +jdk/jfr/jcmd/TestJcmdStartWithSettings.java,, +jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,, +jdk/jfr/jvm/TestGetAllEventClasses.java,,Compilation failed +jdk/jfr/jvm/TestJfrJavaBase.java,, +jdk/jfr/startupargs/TestStartRecording.java,, +jdk/modules/incubator/ImageModules.java,, +jdk/net/Sockets/ExtOptionTest.java,, +jdk/net/Sockets/QuickAckTest.java,, +lib/security/cacerts/VerifyCACerts.java,, +sun/management/jmxremote/bootstrap/CustomLauncherTest.java,, +sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,, +sun/management/jmxremote/bootstrap/LocalManagementTest.java,, +sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,, +sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, +sun/management/jmxremote/startstop/JMXStartStopTest.java,, +sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, +sun/management/jmxremote/startstop/JMXStatusTest.java,, +sun/management/jdp/JdpDefaultsTest.java,, +sun/management/jdp/JdpJmxRemoteDynamicPortTest.java,, +sun/management/jdp/JdpOffTest.java,, +sun/management/jdp/JdpSpecificAddressTest.java,, +sun/text/resources/LocaleDataTest.java,, +sun/tools/jcmd/TestJcmdSanity.java,, +sun/tools/jhsdb/AlternateHashingTest.java,, +sun/tools/jhsdb/BasicLauncherTest.java,, +sun/tools/jhsdb/HeapDumpTest.java,, +sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,, +sun/tools/jinfo/BasicJInfoTest.java,, +sun/tools/jinfo/JInfoTest.java,, +sun/tools/jmap/BasicJMapTest.java,, +sun/tools/jstack/BasicJStackTest.java,, +sun/tools/jstack/DeadlockDetectionTest.java,, +sun/tools/jstatd/TestJstatdExternalRegistry.java,, +sun/tools/jstatd/TestJstatdPort.java,,Flaky +sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky +sun/util/calendar/zi/TestZoneInfo310.java,, +tools/jar/modularJar/Basic.java,, +tools/jar/multiRelease/Basic.java,, +tools/jimage/JImageExtractTest.java,, +tools/jimage/JImageTest.java,, +tools/jlink/JLinkTest.java,, +tools/jlink/plugins/IncludeLocalesPluginTest.java,, +tools/jmod/hashes/HashesTest.java,, +tools/launcher/BigJar.java,b/111611473, +tools/launcher/HelpFlagsTest.java,,java.lang.AssertionError: HelpFlagsTest failed: Tool jfr not covered by this test. Add specification to jdkTools array! +tools/launcher/VersionCheck.java,,java.lang.AssertionError: VersionCheck failed: testToolVersion: [jfr]; +tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,, diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv new file mode 100644 index 000000000..ba993814f --- /dev/null +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -0,0 +1,58 @@ +test name,bug id,comment +async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29 +benchmark/test-benchmark-fs.js,, +benchmark/test-benchmark-napi.js,, +doctool/test-make-doc.js,b/68848110,Expected to fail. +internet/test-dgram-multicast-set-interface-lo.js,b/162798882, +internet/test-doctool-versions.js,, +internet/test-uv-threadpool-schedule.js,, +parallel/test-cluster-dgram-reuse.js,b/64024294, +parallel/test-dgram-bind-fd.js,b/132447356, +parallel/test-dgram-socket-buffer-size.js,b/68847921, +parallel/test-dns-channel-timeout.js,b/161893056, +parallel/test-fs-access.js,, +parallel/test-fs-watchfile.js,,Flaky - File already exists error +parallel/test-fs-write-stream.js,b/166819807,Flaky +parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky +parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky +parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 +parallel/test-os.js,b/63997097, +parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE +parallel/test-process-uid-gid.js,, +parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE +parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE +pseudo-tty/test-assert-colors.js,b/162801321, +pseudo-tty/test-assert-no-color.js,b/162801321, +pseudo-tty/test-assert-position-indicator.js,b/162801321, +pseudo-tty/test-async-wrap-getasyncid-tty.js,b/162801321, +pseudo-tty/test-fatal-error.js,b/162801321, +pseudo-tty/test-handle-wrap-isrefed-tty.js,b/162801321, +pseudo-tty/test-readable-tty-keepalive.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset-process-exit.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset-signal.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset.js,b/162801321, +pseudo-tty/test-stderr-stdout-handle-sigwinch.js,b/162801321, +pseudo-tty/test-stdout-read.js,b/162801321, +pseudo-tty/test-tty-color-support.js,b/162801321, +pseudo-tty/test-tty-isatty.js,b/162801321, +pseudo-tty/test-tty-stdin-call-end.js,b/162801321, +pseudo-tty/test-tty-stdin-end.js,b/162801321, +pseudo-tty/test-stdin-write.js,b/162801321, +pseudo-tty/test-tty-stdout-end.js,b/162801321, +pseudo-tty/test-tty-stdout-resize.js,b/162801321, +pseudo-tty/test-tty-stream-constructors.js,b/162801321, +pseudo-tty/test-tty-window-size.js,b/162801321, +pseudo-tty/test-tty-wrap.js,b/162801321, +pummel/test-heapdump-http2.js,,Flaky +pummel/test-net-pingpong.js,, +pummel/test-vm-memleak.js,b/162799436, +pummel/test-watch-file.js,,Flaky - Timeout +sequential/test-child-process-pass-fd.js,b/63926391,Flaky +sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE +sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout +tick-processor/test-tick-processor-builtin.js,, diff --git a/test/runtimes/exclude/php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv new file mode 100644 index 000000000..a73f3bcfb --- /dev/null +++ b/test/runtimes/exclude/php7.3.6.csv @@ -0,0 +1,46 @@ +test name,bug id,comment +ext/intl/tests/bug77895.phpt,, +ext/intl/tests/dateformat_bug65683_2.phpt,, +ext/mbstring/tests/bug76319.phpt,, +ext/mbstring/tests/bug76958.phpt,, +ext/mbstring/tests/bug77025.phpt,, +ext/mbstring/tests/bug77165.phpt,, +ext/mbstring/tests/bug77454.phpt,, +ext/mbstring/tests/mb_convert_encoding_leak.phpt,, +ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, +ext/session/tests/session_module_name_variation4.phpt,,Flaky +ext/session/tests/session_set_save_handler_class_018.phpt,, +ext/session/tests/session_set_save_handler_iface_003.phpt,, +ext/session/tests/session_set_save_handler_sid_001.phpt,, +ext/session/tests/session_set_save_handler_variation4.phpt,, +ext/standard/tests/file/disk.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_basic.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_error.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_variation.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_basic.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_error.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_variation.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/fopen_variation19.phpt,b/162894964, +ext/standard/tests/file/lstat_stat_variation14.phpt,,Flaky +ext/standard/tests/file/php_fd_wrapper_01.phpt,, +ext/standard/tests/file/php_fd_wrapper_02.phpt,, +ext/standard/tests/file/php_fd_wrapper_03.phpt,, +ext/standard/tests/file/php_fd_wrapper_04.phpt,, +ext/standard/tests/file/realpath_bug77484.phpt,b/162894969, +ext/standard/tests/file/rename_variation.phpt,b/68717309, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, +ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, +ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, +ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb +ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky +ext/standard/tests/streams/stream_socket_sendto.phpt,, +ext/standard/tests/strings/007.phpt,, +sapi/cli/tests/upload_2G.phpt,, +tests/output/stream_isatty_err.phpt,b/68720279, +tests/output/stream_isatty_in-err.phpt,b/68720282, +tests/output/stream_isatty_in-out-err.phpt,, +tests/output/stream_isatty_in-out.phpt,b/68720299, +tests/output/stream_isatty_out-err.phpt,b/68720311, +tests/output/stream_isatty_out.phpt,b/68720325, +Zend/tests/concat_003.phpt,b/162896021, diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv new file mode 100644 index 000000000..8760f8951 --- /dev/null +++ b/test/runtimes/exclude/python3.7.3.csv @@ -0,0 +1,21 @@ +test name,bug id,comment +test_asyncio,,Fails on Docker. +test_asyncore,b/162973328, +test_epoll,b/162983393, +test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. +test_httplib,b/163000009,OSError: [Errno 98] Address already in use +test_imaplib,b/162979661, +test_logging,b/162980079, +test_multiprocessing_fork,,Flaky. Sometimes times out. +test_multiprocessing_forkserver,,Flaky. Sometimes times out. +test_multiprocessing_main_handling,,Flaky. Sometimes times out. +test_multiprocessing_spawn,,Flaky. Sometimes times out. +test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted +test_pty,b/162979921, +test_readline,b/162980389,TestReadline hangs forever +test_resource,b/76174079, +test_selectors,b/76116849,OSError not raised with epoll +test_smtplib,b/162980434,unclosed sockets +test_signal,,Flaky - signal: alarm clock +test_socket,b/75983380, +test_subprocess,b/162980831, diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude_go1.12.csv deleted file mode 100644 index 8c8ae0c5d..000000000 --- a/test/runtimes/exclude_go1.12.csv +++ /dev/null @@ -1,16 +0,0 @@ -test name,bug id,comment -cgo_errors,,FLAKY -cgo_test,,FLAKY -go_test:cmd/go,,FLAKY -go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing -go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug. -go_test:os,b/118780122,we have a pollable filesystem but that's a surprise -go_test:os/signal,b/118780860,/dev/pts not properly supported -go_test:runtime,b/118782341,sigtrap not reported or caught or something -go_test:syscall,b/118781998,bad bytes -- bad mem addr -race,b/118782931,thread sanitizer. Works as intended: b/62219744. -runtime:cpu124,b/118778254,segmentation fault -test:0_1,,FLAKY -testasan,, -testcarchive,b/118782924,no sigpipe -testshared,,FLAKY diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude_java11.csv deleted file mode 100644 index c012e5a56..000000000 --- a/test/runtimes/exclude_java11.csv +++ /dev/null @@ -1,126 +0,0 @@ -test name,bug id,comment -com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker -com/sun/jdi/NashornPopFrameTest.java,, -com/sun/jdi/ProcessAttachTest.java,, -com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker -com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, -com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, -com/sun/tools/attach/AttachSelf.java,, -com/sun/tools/attach/BasicTests.java,, -com/sun/tools/attach/PermissionTest.java,, -com/sun/tools/attach/StartManagementAgent.java,, -com/sun/tools/attach/TempDirTest.java,, -com/sun/tools/attach/modules/Driver.java,, -java/lang/Character/CheckScript.java,,Fails in Docker -java/lang/Character/CheckUnicode.java,,Fails in Docker -java/lang/Class/GetPackageBootLoaderChildLayer.java,, -java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker -java/lang/String/nativeEncoding/StringPlatformChars.java,, -java/net/DatagramSocket/ReuseAddressTest.java,, -java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345, -java/net/Inet4Address/PingThis.java,, -java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, -java/net/MulticastSocket/MulticastTTL.java,, -java/net/MulticastSocket/Promiscuous.java,, -java/net/MulticastSocket/SetLoopbackMode.java,, -java/net/MulticastSocket/SetTTLAndGetTTL.java,, -java/net/MulticastSocket/Test.java,, -java/net/MulticastSocket/TestDefaults.java,, -java/net/MulticastSocket/TimeToLive.java,, -java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, -java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported -java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor -java/net/Socket/UrgentDataTest.java,b/111515323, -java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default -java/net/SocketOption/OptionsTest.java,,Fails in Docker -java/net/SocketOption/TcpKeepAliveTest.java,, -java/net/SocketPermission/SocketPermissionTest.java,, -java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker -java/net/httpclient/RequestBuilderTest.java,,Fails in Docker -java/net/httpclient/ShortResponseBody.java,, -java/net/httpclient/ShortResponseBodyWithRetry.java,, -java/nio/channels/AsyncCloseAndInterrupt.java,, -java/nio/channels/AsynchronousServerSocketChannel/Basic.java,, -java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable -java/nio/channels/DatagramChannel/BasicMulticastTests.java,, -java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker -java/nio/channels/DatagramChannel/UseDGWithIPv6.java,, -java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker -java/nio/channels/Selector/OutOfBand.java,, -java/nio/channels/Selector/SelectWithConsumer.java,,Flaky -java/nio/channels/ServerSocketChannel/SocketOptionTests.java,, -java/nio/channels/SocketChannel/LingerOnClose.java,, -java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, -java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker -java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, -java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, -java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker -java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker -java/util/Calendar/JapaneseEraNameTest.java,, -java/util/Currency/CurrencyTest.java,,Fails in Docker -java/util/Currency/ValidateISO4217.java,,Fails in Docker -java/util/Locale/LSRDataTest.java,, -java/util/concurrent/locks/Lock/TimedAcquireLeak.java,, -java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker -java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,, -java/util/logging/TestLoggerWeakRefLeak.java,, -javax/imageio/AppletResourceTest.java,, -javax/management/security/HashedPasswordFileTest.java,, -javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker -javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,, -jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, -jdk/jfr/event/runtime/TestThreadParkEvent.java,, -jdk/jfr/event/sampling/TestNative.java,, -jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,, -jdk/jfr/jcmd/TestJcmdConfigure.java,, -jdk/jfr/jcmd/TestJcmdDump.java,, -jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,, -jdk/jfr/jcmd/TestJcmdDumpLimited.java,, -jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,, -jdk/jfr/jcmd/TestJcmdLegacy.java,, -jdk/jfr/jcmd/TestJcmdSaveToFile.java,, -jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,, -jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,, -jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,, -jdk/jfr/jcmd/TestJcmdStartStopDefault.java,, -jdk/jfr/jcmd/TestJcmdStartWithOptions.java,, -jdk/jfr/jcmd/TestJcmdStartWithSettings.java,, -jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,, -jdk/jfr/jvm/TestJfrJavaBase.java,, -jdk/jfr/startupargs/TestStartRecording.java,, -jdk/modules/incubator/ImageModules.java,, -jdk/net/Sockets/ExtOptionTest.java,, -jdk/net/Sockets/QuickAckTest.java,, -lib/security/cacerts/VerifyCACerts.java,, -sun/management/jmxremote/bootstrap/CustomLauncherTest.java,, -sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,, -sun/management/jmxremote/bootstrap/LocalManagementTest.java,, -sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,, -sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, -sun/management/jmxremote/startstop/JMXStartStopTest.java,, -sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, -sun/management/jmxremote/startstop/JMXStatusTest.java,, -sun/text/resources/LocaleDataTest.java,, -sun/tools/jcmd/TestJcmdSanity.java,, -sun/tools/jhsdb/AlternateHashingTest.java,, -sun/tools/jhsdb/BasicLauncherTest.java,, -sun/tools/jhsdb/HeapDumpTest.java,, -sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,, -sun/tools/jinfo/BasicJInfoTest.java,, -sun/tools/jinfo/JInfoTest.java,, -sun/tools/jmap/BasicJMapTest.java,, -sun/tools/jstack/BasicJStackTest.java,, -sun/tools/jstack/DeadlockDetectionTest.java,, -sun/tools/jstatd/TestJstatdExternalRegistry.java,, -sun/tools/jstatd/TestJstatdPort.java,,Flaky -sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky -sun/util/calendar/zi/TestZoneInfo310.java,, -tools/jar/modularJar/Basic.java,, -tools/jar/multiRelease/Basic.java,, -tools/jimage/JImageExtractTest.java,, -tools/jimage/JImageTest.java,, -tools/jlink/JLinkTest.java,, -tools/jlink/plugins/IncludeLocalesPluginTest.java,, -tools/jmod/hashes/HashesTest.java,, -tools/launcher/BigJar.java,b/111611473, -tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,, diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude_nodejs12.4.0.csv deleted file mode 100644 index 4ab4e2927..000000000 --- a/test/runtimes/exclude_nodejs12.4.0.csv +++ /dev/null @@ -1,47 +0,0 @@ -test name,bug id,comment -benchmark/test-benchmark-fs.js,, -benchmark/test-benchmark-module.js,, -benchmark/test-benchmark-napi.js,, -doctool/test-make-doc.js,b/68848110,Expected to fail. -fixtures/test-error-first-line-offset.js,, -fixtures/test-fs-readfile-error.js,, -fixtures/test-fs-stat-sync-overflow.js,, -internet/test-dgram-broadcast-multi-process.js,, -internet/test-dgram-multicast-multi-process.js,, -internet/test-dgram-multicast-set-interface-lo.js,, -parallel/test-cluster-dgram-reuse.js,b/64024294, -parallel/test-dgram-bind-fd.js,b/132447356, -parallel/test-dgram-create-socket-handle-fd.js,b/132447238, -parallel/test-dgram-createSocket-type.js,b/68847739, -parallel/test-dgram-socket-buffer-size.js,b/68847921, -parallel/test-fs-access.js,, -parallel/test-fs-write-stream-double-close.js,, -parallel/test-fs-write-stream-throw-type-error.js,b/110226209, -parallel/test-fs-write-stream.js,, -parallel/test-http2-respond-file-error-pipe-offset.js,, -parallel/test-os.js,, -parallel/test-process-uid-gid.js,, -pseudo-tty/test-assert-colors.js,, -pseudo-tty/test-assert-no-color.js,, -pseudo-tty/test-assert-position-indicator.js,, -pseudo-tty/test-async-wrap-getasyncid-tty.js,, -pseudo-tty/test-fatal-error.js,, -pseudo-tty/test-handle-wrap-isrefed-tty.js,, -pseudo-tty/test-readable-tty-keepalive.js,, -pseudo-tty/test-set-raw-mode-reset-process-exit.js,, -pseudo-tty/test-set-raw-mode-reset-signal.js,, -pseudo-tty/test-set-raw-mode-reset.js,, -pseudo-tty/test-stderr-stdout-handle-sigwinch.js,, -pseudo-tty/test-stdout-read.js,, -pseudo-tty/test-tty-color-support.js,, -pseudo-tty/test-tty-isatty.js,, -pseudo-tty/test-tty-stdin-call-end.js,, -pseudo-tty/test-tty-stdin-end.js,, -pseudo-tty/test-stdin-write.js,, -pseudo-tty/test-tty-stdout-end.js,, -pseudo-tty/test-tty-stdout-resize.js,, -pseudo-tty/test-tty-stream-constructors.js,, -pseudo-tty/test-tty-window-size.js,, -pseudo-tty/test-tty-wrap.js,, -pummel/test-net-pingpong.js,, -pummel/test-vm-memleak.js,, diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude_php7.3.6.csv deleted file mode 100644 index 456bf7487..000000000 --- a/test/runtimes/exclude_php7.3.6.csv +++ /dev/null @@ -1,29 +0,0 @@ -test name,bug id,comment -ext/intl/tests/bug77895.phpt,, -ext/intl/tests/dateformat_bug65683_2.phpt,, -ext/mbstring/tests/bug76319.phpt,, -ext/mbstring/tests/bug76958.phpt,, -ext/mbstring/tests/bug77025.phpt,, -ext/mbstring/tests/bug77165.phpt,, -ext/mbstring/tests/bug77454.phpt,, -ext/mbstring/tests/mb_convert_encoding_leak.phpt,, -ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, -ext/standard/tests/file/filetype_variation.phpt,, -ext/standard/tests/file/fopen_variation19.phpt,, -ext/standard/tests/file/php_fd_wrapper_01.phpt,, -ext/standard/tests/file/php_fd_wrapper_02.phpt,, -ext/standard/tests/file/php_fd_wrapper_03.phpt,, -ext/standard/tests/file/php_fd_wrapper_04.phpt,, -ext/standard/tests/file/realpath_bug77484.phpt,, -ext/standard/tests/file/rename_variation.phpt,b/68717309, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,, -ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, -ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, -ext/standard/tests/network/bug20134.phpt,, -tests/output/stream_isatty_err.phpt,b/68720279, -tests/output/stream_isatty_in-err.phpt,b/68720282, -tests/output/stream_isatty_in-out-err.phpt,, -tests/output/stream_isatty_in-out.phpt,b/68720299, -tests/output/stream_isatty_out-err.phpt,b/68720311, -tests/output/stream_isatty_out.phpt,b/68720325, diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude_python3.7.3.csv deleted file mode 100644 index 2b9947212..000000000 --- a/test/runtimes/exclude_python3.7.3.csv +++ /dev/null @@ -1,27 +0,0 @@ -test name,bug id,comment -test_asynchat,b/76031995,SO_REUSEADDR -test_asyncio,,Fails on Docker. -test_asyncore,b/76031995,SO_REUSEADDR -test_epoll,, -test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. -test_ftplib,,Fails in Docker -test_httplib,b/76031995,SO_REUSEADDR -test_imaplib,, -test_logging,, -test_multiprocessing_fork,,Flaky. Sometimes times out. -test_multiprocessing_forkserver,,Flaky. Sometimes times out. -test_multiprocessing_main_handling,,Flaky. Sometimes times out. -test_multiprocessing_spawn,,Flaky. Sometimes times out. -test_nntplib,b/76031995,tests should not set SO_REUSEADDR -test_poplib,,Fails on Docker -test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted -test_pty,b/76157709,out of pty devices -test_readline,b/76157709,out of pty devices -test_resource,b/76174079, -test_selectors,b/76116849,OSError not raised with epoll -test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets -test_socket,b/75983380, -test_ssl,b/76031995,SO_REUSEADDR -test_subprocess,, -test_support,b/76031995,SO_REUSEADDR -test_telnetlib,b/76031995,SO_REUSEADDR diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD index f76e2ddc0..fdc6d3173 100644 --- a/test/runtimes/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -1,28 +1,11 @@ -load("//tools:defs.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) go_binary( name = "proctor", - srcs = [ - "go.go", - "java.go", - "nodejs.go", - "php.go", - "proctor.go", - "python.go", - ], + srcs = ["main.go"], pure = True, visibility = ["//test/runtimes:__pkg__"], -) - -go_test( - name = "proctor_test", - size = "small", - srcs = ["proctor_test.go"], - library = ":proctor", - pure = True, - deps = [ - "//pkg/test/testutil", - ], + deps = ["//test/runtimes/proctor/lib"], ) diff --git a/test/runtimes/proctor/lib/BUILD b/test/runtimes/proctor/lib/BUILD new file mode 100644 index 000000000..0c8367dfe --- /dev/null +++ b/test/runtimes/proctor/lib/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "lib", + srcs = [ + "go.go", + "java.go", + "lib.go", + "nodejs.go", + "php.go", + "python.go", + ], + visibility = ["//test/runtimes/proctor:__pkg__"], +) + +go_test( + name = "lib_test", + size = "small", + srcs = ["lib_test.go"], + library = ":lib", + deps = ["//pkg/test/testutil"], +) diff --git a/test/runtimes/proctor/go.go b/test/runtimes/proctor/lib/go.go index 3e2d5d8db..5c48fb60b 100644 --- a/test/runtimes/proctor/go.go +++ b/test/runtimes/proctor/lib/go.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" @@ -59,7 +59,7 @@ func (goRunner) ListTests() ([]string, error) { } // Go tests on disk. - diskSlice, err := search(goTestDir, goTestRegEx) + diskSlice, err := Search(goTestDir, goTestRegEx) if err != nil { return nil, err } @@ -74,17 +74,26 @@ func (goRunner) ListTests() ([]string, error) { return append(toolSlice, diskFiltered...), nil } -// TestCmd implements TestRunner.TestCmd. -func (goRunner) TestCmd(test string) *exec.Cmd { - // Check if test exists on disk by searching for file of the same name. - // This will determine whether or not it is a Go test on disk. - if strings.HasSuffix(test, ".go") { - // Test has suffix ".go" which indicates a disk test, run it as such. - cmd := exec.Command("go", "run", "run.go", "-v", "--", test) +// TestCmds implements TestRunner.TestCmds. +func (goRunner) TestCmds(tests []string) []*exec.Cmd { + var toolTests, onDiskTests []string + for _, test := range tests { + if strings.HasSuffix(test, ".go") { + onDiskTests = append(onDiskTests, test) + } else { + toolTests = append(toolTests, "^"+test+"$") + } + } + + var cmds []*exec.Cmd + if len(toolTests) > 0 { + cmds = append(cmds, exec.Command("go", "tool", "dist", "test", "-v", "-no-rebuild", "-run", strings.Join(toolTests, "\\|"))) + } + if len(onDiskTests) > 0 { + cmd := exec.Command("go", append([]string{"run", "run.go", "-v", "--"}, onDiskTests...)...) cmd.Dir = goTestDir - return cmd + cmds = append(cmds, cmd) } - // No ".go" suffix, run as a tool test. - return exec.Command("go", "tool", "dist", "test", "-run", test) + return cmds } diff --git a/test/runtimes/proctor/java.go b/test/runtimes/proctor/lib/java.go index 8b362029d..3105011ff 100644 --- a/test/runtimes/proctor/java.go +++ b/test/runtimes/proctor/lib/java.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" @@ -60,12 +60,17 @@ func (javaRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (javaRunner) TestCmd(test string) *exec.Cmd { - args := []string{ - "-noreport", - "-dir:" + javaTestDir, - test, - } - return exec.Command("jtreg", args...) +// TestCmds implements TestRunner.TestCmds. +func (javaRunner) TestCmds(tests []string) []*exec.Cmd { + args := append( + []string{ + "-agentvm", // Execute each action using a pool of reusable JVMs. + "-dir:" + javaTestDir, // Base directory for test files and directories. + "-noreport", // Do not generate a final report. + "-timeoutFactor:20", // Extend the default timeout (2 min) of all tests by this factor. + "-verbose:nopass", // Verbose output but supress it for tests that passed. + }, + tests..., + ) + return []*exec.Cmd{exec.Command("jtreg", args...)} } diff --git a/test/runtimes/proctor/proctor.go b/test/runtimes/proctor/lib/lib.go index b54abe434..f2ba82498 100644 --- a/test/runtimes/proctor/proctor.go +++ b/test/runtimes/proctor/lib/lib.go @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor runs the test for a particular runtime. It is meant to be -// included in Docker images for all runtime tests. -package main +// Package lib contains proctor functions. +package lib import ( - "flag" "fmt" - "log" "os" "os/exec" "os/signal" @@ -34,68 +31,15 @@ type TestRunner interface { // ListTests returns a string slice of tests available to run. ListTests() ([]string, error) - // TestCmd returns an *exec.Cmd that will run the given test. - TestCmd(test string) *exec.Cmd + // TestCmds returns a slice of *exec.Cmd that will run the given tests. + // There is no correlation between the number of exec.Cmds returned and the + // number of tests. It could return one command to run all tests or a few + // commands that collectively run all. + TestCmds(tests []string) []*exec.Cmd } -var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - testName = flag.String("test", "", "run a single test from the list of available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") -) - -func main() { - flag.Parse() - - if *pause { - pauseAndReap() - panic("pauseAndReap should never return") - } - - if *runtime == "" { - log.Fatalf("runtime flag must be provided") - } - - tr, err := testRunnerForRuntime(*runtime) - if err != nil { - log.Fatalf("%v", err) - } - - // List tests. - if *list { - tests, err := tr.ListTests() - if err != nil { - log.Fatalf("failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - - var tests []string - if *testName == "" { - // Run every test. - tests, err = tr.ListTests() - if err != nil { - log.Fatalf("failed to get all tests: %v", err) - } - } else { - // Run a single test. - tests = []string{*testName} - } - for _, test := range tests { - cmd := tr.TestCmd(test) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) - } - } -} - -// testRunnerForRuntime returns a new TestRunner for the given runtime. -func testRunnerForRuntime(runtime string) (TestRunner, error) { +// TestRunnerForRuntime returns a new TestRunner for the given runtime. +func TestRunnerForRuntime(runtime string) (TestRunner, error) { switch runtime { case "go": return goRunner{}, nil @@ -111,8 +55,8 @@ func testRunnerForRuntime(runtime string) (TestRunner, error) { return nil, fmt.Errorf("invalid runtime %q", runtime) } -// pauseAndReap is like init. It runs forever and reaps any children. -func pauseAndReap() { +// PauseAndReap is like init. It runs forever and reaps any children. +func PauseAndReap() { // Get notified of any new children. ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGCHLD) @@ -132,9 +76,9 @@ func pauseAndReap() { } } -// search is a helper function to find tests in the given directory that match +// Search is a helper function to find tests in the given directory that match // the regex. -func search(root string, testFilter *regexp.Regexp) ([]string, error) { +func Search(root string, testFilter *regexp.Regexp) ([]string, error) { var testSlice []string err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { diff --git a/test/runtimes/proctor/proctor_test.go b/test/runtimes/proctor/lib/lib_test.go index 6ef2de085..1193d2e28 100644 --- a/test/runtimes/proctor/proctor_test.go +++ b/test/runtimes/proctor/lib/lib_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "io/ioutil" @@ -47,7 +47,7 @@ func TestSearchEmptyDir(t *testing.T) { var want []string testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) + got, err := Search(td, testFilter) if err != nil { t.Errorf("search error: %v", err) } @@ -116,7 +116,7 @@ func TestSearch(t *testing.T) { } testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) + got, err := Search(td, testFilter) if err != nil { t.Errorf("search error: %v", err) } diff --git a/test/runtimes/proctor/nodejs.go b/test/runtimes/proctor/lib/nodejs.go index bd57db444..320597aa5 100644 --- a/test/runtimes/proctor/nodejs.go +++ b/test/runtimes/proctor/lib/nodejs.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "os/exec" @@ -32,15 +32,15 @@ var _ TestRunner = nodejsRunner{} // ListTests implements TestRunner.ListTests. func (nodejsRunner) ListTests() ([]string, error) { - testSlice, err := search(nodejsTestDir, nodejsTestRegEx) + testSlice, err := Search(nodejsTestDir, nodejsTestRegEx) if err != nil { return nil, err } return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (nodejsRunner) TestCmd(test string) *exec.Cmd { - args := []string{filepath.Join("tools", "test.py"), test} - return exec.Command("/usr/bin/python", args...) +// TestCmds implements TestRunner.TestCmds. +func (nodejsRunner) TestCmds(tests []string) []*exec.Cmd { + args := append([]string{filepath.Join("tools", "test.py"), "--timeout=180"}, tests...) + return []*exec.Cmd{exec.Command("/usr/bin/python", args...)} } diff --git a/test/runtimes/proctor/php.go b/test/runtimes/proctor/lib/php.go index 9115040e1..b67a60a97 100644 --- a/test/runtimes/proctor/php.go +++ b/test/runtimes/proctor/lib/php.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "os/exec" "regexp" + "strings" ) var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`) @@ -28,15 +29,15 @@ var _ TestRunner = phpRunner{} // ListTests implements TestRunner.ListTests. func (phpRunner) ListTests() ([]string, error) { - testSlice, err := search(".", phpTestRegEx) + testSlice, err := Search(".", phpTestRegEx) if err != nil { return nil, err } return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (phpRunner) TestCmd(test string) *exec.Cmd { - args := []string{"test", "TESTS=" + test} - return exec.Command("make", args...) +// TestCmds implements TestRunner.TestCmds. +func (phpRunner) TestCmds(tests []string) []*exec.Cmd { + args := []string{"test", "TESTS=" + strings.Join(tests, " ")} + return []*exec.Cmd{exec.Command("make", args...)} } diff --git a/test/runtimes/proctor/python.go b/test/runtimes/proctor/lib/python.go index b9e0fbe6f..429bfd850 100644 --- a/test/runtimes/proctor/python.go +++ b/test/runtimes/proctor/lib/python.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" @@ -42,8 +42,8 @@ func (pythonRunner) ListTests() ([]string, error) { return toolSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (pythonRunner) TestCmd(test string) *exec.Cmd { - args := []string{"-m", "test", test} - return exec.Command("./python", args...) +// TestCmds implements TestRunner.TestCmds. +func (pythonRunner) TestCmds(tests []string) []*exec.Cmd { + args := append([]string{"-m", "test"}, tests...) + return []*exec.Cmd{exec.Command("./python", args...)} } diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go new file mode 100644 index 000000000..e5607ac92 --- /dev/null +++ b/test/runtimes/proctor/main.go @@ -0,0 +1,85 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary proctor runs the test for a particular runtime. It is meant to be +// included in Docker images for all runtime tests. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "gvisor.dev/gvisor/test/runtimes/proctor/lib" +) + +var ( + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testNames = flag.String("tests", "", "run a subset of the available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") +) + +func main() { + flag.Parse() + + if *pause { + lib.PauseAndReap() + panic("pauseAndReap should never return") + } + + if *runtime == "" { + log.Fatalf("runtime flag must be provided") + } + + tr, err := lib.TestRunnerForRuntime(*runtime) + if err != nil { + log.Fatalf("%v", err) + } + + // List tests. + if *list { + tests, err := tr.ListTests() + if err != nil { + log.Fatalf("failed to list tests: %v", err) + } + for _, test := range tests { + fmt.Println(test) + } + return + } + + var tests []string + if *testNames == "" { + // Run every test. + tests, err = tr.ListTests() + if err != nil { + log.Fatalf("failed to get all tests: %v", err) + } + } else { + // Run subset of test. + tests = strings.Split(*testNames, ",") + } + + // Run tests. + cmds := tr.TestCmds(tests) + for _, cmd := range cmds { + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } + } +} diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD index 3972244b9..70cc01594 100644 --- a/test/runtimes/runner/BUILD +++ b/test/runtimes/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -7,15 +7,5 @@ go_binary( testonly = 1, srcs = ["main.go"], visibility = ["//test/runtimes:__pkg__"], - deps = [ - "//pkg/test/dockerutil", - "//pkg/test/testutil", - ], -) - -go_test( - name = "exclude_test", - size = "small", - srcs = ["exclude_test.go"], - library = ":runner", + deps = ["//test/runtimes/runner/lib"], ) diff --git a/test/runtimes/runner/lib/BUILD b/test/runtimes/runner/lib/BUILD new file mode 100644 index 000000000..d308f41b0 --- /dev/null +++ b/test/runtimes/runner/lib/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "lib", + testonly = 1, + srcs = ["lib.go"], + visibility = ["//test/runtimes/runner:__pkg__"], + deps = [ + "//pkg/log", + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) + +go_test( + name = "lib_test", + size = "small", + srcs = ["exclude_test.go"], + library = ":lib", +) diff --git a/test/runtimes/runner/exclude_test.go b/test/runtimes/runner/lib/exclude_test.go index c08755894..f996e895b 100644 --- a/test/runtimes/runner/exclude_test.go +++ b/test/runtimes/runner/lib/exclude_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "flag" @@ -20,14 +20,16 @@ import ( "testing" ) +var excludeFile = flag.String("exclude_file", "", "file to test (standard format)") + func TestMain(m *testing.M) { flag.Parse() os.Exit(m.Run()) } // Test that the exclude file parses without error. -func TestBlacklists(t *testing.T) { - ex, err := getExcludes() +func TestExcludelist(t *testing.T) { + ex, err := getExcludes(*excludeFile) if err != nil { t.Fatalf("error parsing exclude file: %v", err) } diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go new file mode 100644 index 000000000..78285cb0e --- /dev/null +++ b/test/runtimes/runner/lib/lib.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lib provides utilities for runner. +package lib + +import ( + "context" + "encoding/csv" + "fmt" + "io" + "os" + "sort" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// RunTests is a helper that is called by main. It exists so that we can run +// defered functions before exiting. It returns an exit code that should be +// passed to os.Exit. +func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int { + // Get tests to exclude.. + excludes, err := getExcludes(excludeFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) + return 1 + } + + // Construct the shared docker instance. + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(lang)) + defer d.CleanUp(ctx) + + if err := testutil.TouchShardStatusFile(); err != nil { + fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) + return 1 + } + + // Get a slice of tests to run. This will also start a single Docker + // container that will be used to run each test. The final test will + // stop the Docker container. + tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return 1 + } + + m := testing.MainStart(testDeps{}, tests, nil, nil) + return m.Run() +} + +// getTests executes all tests as table tests. +func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { + // Start the container. + opts := dockerutil.RunOpts{ + Image: fmt.Sprintf("runtimes/%s", image), + } + d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") + if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { + return nil, fmt.Errorf("docker run failed: %v", err) + } + + // Get a list of all tests in the image. + list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--list") + if err != nil { + return nil, fmt.Errorf("docker exec failed: %v", err) + } + + // Calculate a subset of tests to run corresponding to the current + // shard. + tests := strings.Fields(list) + sort.Strings(tests) + indices, err := testutil.TestIndicesForShard(len(tests)) + if err != nil { + return nil, fmt.Errorf("TestsForShard() failed: %v", err) + } + + var itests []testing.InternalTest + for i := 0; i < len(indices); i += batchSize { + var tcs []string + end := i + batchSize + if end > len(indices) { + end = len(indices) + } + for _, tc := range indices[i:end] { + // Add test if not excluded. + if _, ok := excludes[tests[tc]]; ok { + log.Infof("Skipping test case %s\n", tests[tc]) + continue + } + tcs = append(tcs, tests[tc]) + } + itests = append(itests, testing.InternalTest{ + Name: strings.Join(tcs, ", "), + F: func(t *testing.T) { + var ( + now = time.Now() + done = make(chan struct{}) + output string + err error + ) + + go func() { + fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ",")) + close(done) + }() + + select { + case <-done: + if err == nil { + fmt.Printf("PASS: (%v)\n\n", time.Since(now)) + return + } + t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) + case <-time.After(timeout): + t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) + } + }, + }) + } + + return itests, nil +} + +// getBlacklist reads the exclude file and returns a set of test names to +// exclude. +func getExcludes(excludeFile string) (map[string]struct{}, error) { + excludes := make(map[string]struct{}) + if excludeFile == "" { + return excludes, nil + } + f, err := os.Open(excludeFile) + if err != nil { + return nil, err + } + defer f.Close() + + r := csv.NewReader(f) + + // First line is header. Skip it. + if _, err := r.Read(); err != nil { + return nil, err + } + + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + excludes[record[0]] = struct{}{} + } + return excludes, nil +} + +// testDeps implements testing.testDeps (an unexported interface), and is +// required to use testing.MainStart. +type testDeps struct{} + +func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } +func (f testDeps) StartCPUProfile(io.Writer) error { return nil } +func (f testDeps) StopCPUProfile() {} +func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } +func (f testDeps) ImportPath() string { return "" } +func (f testDeps) StartTestLog(io.Writer) {} +func (f testDeps) StopTestLog() error { return nil } diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go index 54d1169ef..ec79a22c2 100644 --- a/test/runtimes/runner/main.go +++ b/test/runtimes/runner/main.go @@ -16,175 +16,27 @@ package main import ( - "encoding/csv" "flag" "fmt" - "io" "os" - "sort" - "strings" - "testing" "time" - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/test/runtimes/runner/lib" ) var ( lang = flag.String("lang", "", "language runtime to test") image = flag.String("image", "", "docker image with runtime tests") excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") + batchSize = flag.Int("batch", 50, "number of test cases run in one command") + timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") ) -// Wait time for each test to run. -const timeout = 5 * time.Minute - func main() { flag.Parse() if *lang == "" || *image == "" { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(runTests()) + os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout)) } - -// runTests is a helper that is called by main. It exists so that we can run -// defered functions before exiting. It returns an exit code that should be -// passed to os.Exit. -func runTests() int { - // Get tests to exclude.. - excludes, err := getExcludes() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) - return 1 - } - - // Construct the shared docker instance. - d := dockerutil.MakeDocker(testutil.DefaultLogger(*lang)) - defer d.CleanUp() - - // Get a slice of tests to run. This will also start a single Docker - // container that will be used to run each test. The final test will - // stop the Docker container. - tests, err := getTests(d, excludes) - if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) - return 1 - } - - m := testing.MainStart(testDeps{}, tests, nil, nil) - return m.Run() -} - -// getTests executes all tests as table tests. -func getTests(d *dockerutil.Docker, excludes map[string]struct{}) ([]testing.InternalTest, error) { - // Start the container. - opts := dockerutil.RunOpts{ - Image: fmt.Sprintf("runtimes/%s", *image), - } - d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") - if err := d.Spawn(opts, "/proctor/proctor", "--pause"); err != nil { - return nil, fmt.Errorf("docker run failed: %v", err) - } - - // Get a list of all tests in the image. - list, err := d.Exec(dockerutil.RunOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") - if err != nil { - return nil, fmt.Errorf("docker exec failed: %v", err) - } - - // Calculate a subset of tests to run corresponding to the current - // shard. - tests := strings.Fields(list) - sort.Strings(tests) - indices, err := testutil.TestIndicesForShard(len(tests)) - if err != nil { - return nil, fmt.Errorf("TestsForShard() failed: %v", err) - } - - var itests []testing.InternalTest - for _, tci := range indices { - // Capture tc in this scope. - tc := tests[tci] - itests = append(itests, testing.InternalTest{ - Name: tc, - F: func(t *testing.T) { - // Is the test excluded? - if _, ok := excludes[tc]; ok { - t.Skipf("SKIP: excluded test %q", tc) - } - - var ( - now = time.Now() - done = make(chan struct{}) - output string - err error - ) - - go func() { - fmt.Printf("RUNNING %s...\n", tc) - output, err = d.Exec(dockerutil.RunOpts{}, "/proctor/proctor", "--runtime", *lang, "--test", tc) - close(done) - }() - - select { - case <-done: - if err == nil { - fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now)) - return - } - t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output) - case <-time.After(timeout): - t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output) - } - }, - }) - } - - return itests, nil -} - -// getBlacklist reads the exclude file and returns a set of test names to -// exclude. -func getExcludes() (map[string]struct{}, error) { - excludes := make(map[string]struct{}) - if *excludeFile == "" { - return excludes, nil - } - f, err := os.Open(*excludeFile) - if err != nil { - return nil, err - } - defer f.Close() - - r := csv.NewReader(f) - - // First line is header. Skip it. - if _, err := r.Read(); err != nil { - return nil, err - } - - for { - record, err := r.Read() - if err == io.EOF { - break - } - if err != nil { - return nil, err - } - excludes[record[0]] = struct{}{} - } - return excludes, nil -} - -// testDeps implements testing.testDeps (an unexported interface), and is -// required to use testing.MainStart. -type testDeps struct{} - -func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } -func (f testDeps) StartCPUProfile(io.Writer) error { return nil } -func (f testDeps) StopCPUProfile() {} -func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } -func (f testDeps) ImportPath() string { return "" } -func (f testDeps) StartTestLog(io.Writer) {} -func (f testDeps) StopTestLog() error { return nil } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 7c4cd8192..96a775456 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -4,97 +4,84 @@ package(licenses = ["notice"]) syscall_test( test = "//test/syscalls/linux:32bit_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:accept_bind_stream_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:access_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:affinity_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:aio_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:alarm_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:arch_prctl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:bad_test", - vfs2 = "True", ) syscall_test( size = "large", add_overlay = True, test = "//test/syscalls/linux:bind_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:brk_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_capability_test", - vfs2 = "True", ) syscall_test( size = "large", + # Produce too many logs in the debug mode. + debug = False, shard_count = 50, # Takes too long for TSAN. Since this is kind of a stress test that doesn't # involve much concurrency, TSAN's usefulness here is limited anyway. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_stress_test", - vfs2 = "True", + vfs2 = False, ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chdir_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chmod_test", - vfs2 = "True", ) syscall_test( @@ -102,115 +89,96 @@ syscall_test( add_overlay = True, test = "//test/syscalls/linux:chown_test", use_tmpfs = True, # chwon tests require gofer to be running as root. - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chroot_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:clock_getres_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:clock_gettime_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:clock_nanosleep_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:concurrency_test", - vfs2 = "True", ) syscall_test( add_uds_tree = True, test = "//test/syscalls/linux:connect_external_test", use_tmpfs = True, - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:creat_test", - vfs2 = "True", ) syscall_test( + fuse = "True", test = "//test/syscalls/linux:dev_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:dup_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:epoll_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:eventfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:exceptions_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_binary_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:exit_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fadvise64_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fallocate_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fault_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fchdir_test", - vfs2 = "True", ) syscall_test( @@ -222,66 +190,55 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:flock_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fork_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fpsig_fork_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fpsig_nested_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fsync_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:futex_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getcpu_host_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getcpu_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:getdents_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getrandom_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getrusage_test", - vfs2 = "True", ) syscall_test( size = "medium", - add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed. + add_overlay = True, test = "//test/syscalls/linux:inotify_test", ) @@ -289,63 +246,60 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:ioctl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:iptables_test", - vfs2 = "True", +) + +syscall_test( + test = "//test/syscalls/linux:ip6tables_test", ) syscall_test( size = "large", shard_count = 5, test = "//test/syscalls/linux:itimer_test", - vfs2 = "True", +) + +syscall_test( + test = "//test/syscalls/linux:kcov_test", ) syscall_test( test = "//test/syscalls/linux:kill_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:link_test", use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2) - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:lseek_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:madvise_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:memory_accounting_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:mempolicy_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:mincore_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mkdir_test", - vfs2 = "True", ) syscall_test( @@ -357,41 +311,34 @@ syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:mmap_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mount_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:mremap_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:msync_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:munmap_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:network_namespace_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_create_test", - vfs2 = "True", ) syscall_test( @@ -401,22 +348,18 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:packet_socket_raw_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:packet_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:partial_bad_buffer_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:pause_test", - vfs2 = "True", ) syscall_test( @@ -424,7 +367,6 @@ syscall_test( # Takes too long under gotsan to run. tags = ["nogotsan"], test = "//test/syscalls/linux:ping_socket_test", - vfs2 = "True", ) syscall_test( @@ -432,206 +374,169 @@ syscall_test( add_overlay = True, shard_count = 5, test = "//test/syscalls/linux:pipe_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:poll_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:ppoll_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:prctl_setuid_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:prctl_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pread64_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv2_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:priority_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:proc_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_oomscore_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_smaps_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_uid_gid_map_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:pselect_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:ptrace_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:pty_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:pty_root_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwritev2_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwrite64_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_hdrincl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_icmp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:read_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:readahead_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:readv_socket_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:readv_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:rename_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rlimits_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rseq_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rtsignal_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:signalfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sched_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sched_yield_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:seccomp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:select_test", - vfs2 = "True", ) syscall_test( shard_count = 20, test = "//test/syscalls/linux:semaphore_test", - vfs2 = "True", ) syscall_test( @@ -647,12 +552,10 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:splice_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigaction_test", - vfs2 = "True", ) # TODO(b/119826902): Enable once the test passes in runsc. @@ -660,62 +563,52 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:sigiret_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigprocmask_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:sigstop_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigtimedwait_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:shm_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_abstract_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_domain_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:socket_filesystem_non_blocking_test", - vfs2 = "True", ) syscall_test( @@ -723,14 +616,12 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", - vfs2 = "True", ) syscall_test( @@ -739,122 +630,116 @@ syscall_test( # Takes too long for TSAN. Creates a lot of TCP sockets. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", - vfs2 = "True", +) + +syscall_test( + size = "medium", + # Takes too long under gotsan to run. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_nogotsan_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_netlink_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_ipv6_udp_unbound_loopback_netlink_test", ) syscall_test( test = "//test/syscalls/linux:socket_ip_unbound_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netdevice_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_route_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_uevent_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_blocking_local_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_blocking_ip_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_local_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test", - vfs2 = "True", ) syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_local_test", - vfs2 = "True", ) syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_tcp_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_nonblock_local_test", - vfs2 = "True", ) syscall_test( @@ -862,13 +747,11 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_dgram_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_dgram_non_blocking_test", - vfs2 = "True", ) syscall_test( @@ -876,7 +759,6 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", - vfs2 = "True", ) syscall_test( @@ -884,134 +766,112 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_stream_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_abstract_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_dgram_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_filesystem_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 10, test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:statfs_test", - vfs2 = "True", + use_tmpfs = True, # Test specifically relies on TEST_TMPDIR to be tmpfs. ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:stat_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:stat_times_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sticky_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:symlink_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_file_range_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sysinfo_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:syslog_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sysret_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 10, test = "//test/syscalls/linux:tcp_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tgkill_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:timerfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:timers_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:time_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tkill_test", - vfs2 = "True", ) syscall_test( @@ -1021,18 +881,15 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:tuntap_test", - vfs2 = "True", ) syscall_test( add_hostinet = True, test = "//test/syscalls/linux:tuntap_hostinet_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:udp_bind_test", - vfs2 = "True", ) syscall_test( @@ -1040,80 +897,65 @@ syscall_test( add_hostinet = True, shard_count = 10, test = "//test/syscalls/linux:udp_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:uidgid_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:uname_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:unlink_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:unshare_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:utimes_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:vdso_clock_gettime_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vdso_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vsyscall_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vfork_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:wait_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:write_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_unix_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_tcp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_udp_test", - vfs2 = "True", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 9e097c888..6a2ec9787 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -22,6 +22,7 @@ exports_files( "socket_ipv4_tcp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_loopback.cc", + "socket_ipv4_udp_unbound_loopback_nogotsan.cc", "tcp_socket.cc", "udp_bind.cc", "udp_socket.cc", @@ -943,6 +944,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", gtest, "//test/util:posix_error", @@ -1029,6 +1031,24 @@ cc_binary( ) cc_binary( + name = "ip6tables_test", + testonly = 1, + srcs = [ + "ip6tables.cc", + ], + linkstatic = 1, + deps = [ + ":iptables_types", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "itimer_test", testonly = 1, srcs = ["itimer.cc"], @@ -1049,6 +1069,21 @@ cc_binary( ) cc_binary( + name = "kcov_test", + testonly = 1, + srcs = ["kcov.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( name = "kill_test", testonly = 1, srcs = ["kill.cc"], @@ -1330,6 +1365,7 @@ cc_binary( name = "packet_socket_raw_test", testonly = 1, srcs = ["packet_socket_raw.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -1632,12 +1668,14 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", gtest, "//test/util:memory_util", "//test/util:posix_error", + "//test/util:proc_util", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", @@ -1809,6 +1847,7 @@ cc_binary( name = "raw_socket_test", testonly = 1, srcs = ["raw_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -1859,6 +1898,7 @@ cc_binary( srcs = ["readahead.cc"], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:file_descriptor", gtest, "//test/util:temp_path", @@ -1950,6 +1990,7 @@ cc_binary( gtest, "//test/util:logging", "//test/util:multiprocess_util", + "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", ], @@ -2374,12 +2415,50 @@ cc_library( ":socket_test_util", "@com_google_absl//absl/memory", gtest, + "//test/util:posix_error", "//test/util:test_util", ], alwayslink = 1, ) cc_library( + name = "socket_ipv4_udp_unbound_netlink_test_cases", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_netlink.cc", + ], + hdrs = [ + "socket_ipv4_udp_unbound_netlink.h", + ], + deps = [ + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:cleanup", + gtest, + ], + alwayslink = 1, +) + +cc_library( + name = "socket_ipv6_udp_unbound_netlink_test_cases", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_netlink.cc", + ], + hdrs = [ + "socket_ipv6_udp_unbound_netlink.h", + ], + deps = [ + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + gtest, + ], + alwayslink = 1, +) + +cc_library( name = "socket_ipv4_udp_unbound_external_networking_test_cases", testonly = 1, srcs = [ @@ -2716,6 +2795,55 @@ cc_binary( ) cc_binary( + name = "socket_ipv4_udp_unbound_loopback_nogotsan_test", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_loopback_nogotsan.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/memory", + ], +) + +cc_binary( + name = "socket_ipv4_udp_unbound_loopback_netlink_test", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_loopback_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv4_udp_unbound_netlink_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "socket_ipv6_udp_unbound_loopback_netlink_test", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_loopback_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv6_udp_unbound_netlink_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "socket_ip_unbound_test", testonly = 1, srcs = [ @@ -3407,6 +3535,7 @@ cc_binary( name = "tcp_socket_test", testonly = 1, srcs = ["tcp_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -3543,15 +3672,12 @@ cc_binary( ], ) -cc_library( - name = "udp_socket_test_cases", +cc_binary( + name = "udp_socket_test", testonly = 1, - srcs = [ - "udp_socket_errqueue_test_case.cc", - "udp_socket_test_cases.cc", - ], - hdrs = ["udp_socket_test_cases.h"], + srcs = ["udp_socket.cc"], defines = select_system(), + linkstatic = 1, deps = [ ":ip_socket_test_util", ":socket_test_util", @@ -3566,17 +3692,6 @@ cc_library( "//test/util:test_util", "//test/util:thread_util", ], - alwayslink = 1, -) - -cc_binary( - name = "udp_socket_test", - testonly = 1, - srcs = ["udp_socket.cc"], - linkstatic = 1, - deps = [ - ":udp_socket_test_cases", - ], ) cc_binary( diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 3c88c4cbd..1d0d584cd 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -156,11 +156,24 @@ TEST(DevTest, TTYExists) { TEST(DevTest, OpenDevFuse) { // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new // device registration is complete. - SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor()); + SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor() || !IsFUSEEnabled()); ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY)); } +TEST(DevTest, ReadDevFuseWithoutMount) { + // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new + // device registration is complete. + SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor()); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY)); + + std::vector<char> buf(1); + EXPECT_THAT(ReadFd(fd.get(), buf.data(), sizeof(buf)), + SyscallFailsWithErrno(EPERM)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index f57d38dc7..2101e5c9f 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -422,6 +422,28 @@ TEST(EpollTest, CloseFile) { SyscallSucceedsWithValue(0)); } +TEST(EpollTest, PipeReaderHupAfterWriterClosed) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + int pipefds[2]; + ASSERT_THAT(pipe(pipefds), SyscallSucceeds()); + FileDescriptor rfd(pipefds[0]); + FileDescriptor wfd(pipefds[1]); + + ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), rfd.get(), 0, kMagicConstant)); + struct epoll_event result[kFDsPerEpoll]; + // Initially, rfd should not generate any events of interest. + ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0), + SyscallSucceedsWithValue(0)); + // Close the write end of the pipe. + wfd.reset(); + // rfd should now generate EPOLLHUP, which EPOLL_CTL_ADD unconditionally adds + // to the set of events of interest. + ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0), + SyscallSucceedsWithValue(1)); + EXPECT_EQ(result[0].events, EPOLLHUP); + EXPECT_EQ(result[0].data.u64, kMagicConstant); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index e09afafe9..c5acfc794 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -553,7 +553,12 @@ TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { // Hold onto TempPath objects so they are not destructed prematurely. std::vector<TempPath> interpreter_symlinks; std::vector<TempPath> script_symlinks; - for (int i = 0; i < kLinuxMaxSymlinks; i++) { + // Replace both the interpreter and script paths with symlink chains of just + // over half the symlink limit each; this is the minimum required to test that + // the symlink limit applies separately to each traversal, while tolerating + // some symlinks in the resolution of (the original) interpreter_path and + // script_path. + for (int i = 0; i < (kLinuxMaxSymlinks / 2) + 1; i++) { interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); interpreter_path = interpreter_symlinks[i].path(); @@ -679,18 +684,16 @@ TEST(ExecveatTest, UnshareFiles) { const FileDescriptor fd_closed_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - pid_t child; - EXPECT_THAT(child = syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, - 0, 0, 0, 0), - SyscallSucceeds()); + ExecveArray argv = {"test"}; + ExecveArray envp; + std::string child_path = RunfilePath(kBasicWorkload); + pid_t child = + syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, 0, 0, 0, 0); if (child == 0) { - ExecveArray argv = {"test"}; - ExecveArray envp; - ASSERT_THAT( - execve(RunfilePath(kBasicWorkload).c_str(), argv.get(), envp.get()), - SyscallSucceeds()); + execve(child_path.c_str(), argv.get(), envp.get()); _exit(1); } + ASSERT_THAT(child, SyscallSucceeds()); int status; ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 18d2f22c1..3797fd4c8 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -1042,6 +1042,13 @@ class ElfInterpreterStaticTest // Statically linked ELF with a statically linked ELF interpreter. TEST_P(ElfInterpreterStaticTest, Test) { + // TODO(gvisor.dev/issue/3721): Test has been observed to segfault on 5.X + // kernels. + if (!IsRunningOnGvisor()) { + auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); + SKIP_IF(version.major > 4); + } + const std::vector<char> segment_suffix = std::get<0>(GetParam()); const int expected_errno = std::get<1>(GetParam()); diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index cabc2b751..edd23e063 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -179,6 +179,12 @@ TEST_F(AllocateTest, FallocateOtherFDs) { auto sock0 = FileDescriptor(socks[0]); auto sock1 = FileDescriptor(socks[1]); EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + int pipefds[2]; + ASSERT_THAT(pipe(pipefds), SyscallSucceeds()); + EXPECT_THAT(fallocate(pipefds[1], 0, 0, 10), SyscallFailsWithErrno(ESPIPE)); + close(pipefds[0]); + close(pipefds[1]); } } // namespace diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 5467fa2c8..34016d4bd 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -1004,7 +1004,8 @@ TEST(FcntlTest, SetOwnPid) { pid_t pid; EXPECT_THAT(pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(pid)); @@ -1018,7 +1019,8 @@ TEST(FcntlTest, SetOwnPgrp) { pid_t pgid; EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), + SyscallSucceedsWithValue(0)); // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the // negative return value as an error, converting the return value to -1 and @@ -1038,8 +1040,10 @@ TEST(FcntlTest, SetOwnUnset) { // Set and unset pid. pid_t pid; EXPECT_THAT(pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(0)); @@ -1047,8 +1051,10 @@ TEST(FcntlTest, SetOwnUnset) { // Set and unset pgid. pid_t pgid; EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(0)); @@ -1120,7 +1126,7 @@ TEST(FcntlTest, SetOwnExTid) { EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(owner.pid)); @@ -1136,7 +1142,7 @@ TEST(FcntlTest, SetOwnExPid) { EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(owner.pid)); @@ -1152,7 +1158,7 @@ TEST(FcntlTest, SetOwnExPgrp) { EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the // negative return value as an error, converting the return value to -1 and @@ -1176,10 +1182,10 @@ TEST(FcntlTest, SetOwnExUnset) { owner.type = F_OWNER_PID; EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); owner.pid = 0; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(0)); @@ -1188,10 +1194,10 @@ TEST(FcntlTest, SetOwnExUnset) { owner.type = F_OWNER_PGRP; EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); owner.pid = 0; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), SyscallSucceedsWithValue(0)); @@ -1207,7 +1213,7 @@ TEST(FcntlTest, GetOwnExTid) { EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); f_owner_ex got_owner = {}; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), @@ -1225,7 +1231,7 @@ TEST(FcntlTest, GetOwnExPid) { EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); f_owner_ex got_owner = {}; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), @@ -1243,7 +1249,7 @@ TEST(FcntlTest, GetOwnExPgrp) { EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); f_owner_ex got_owner = {}; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index 638a93979..549141cbb 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -185,7 +185,7 @@ TEST_F(FlockTest, TestMultipleHolderSharedExclusive) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderNonblocking) { // This test will verify that a shared lock is denied while // someone holds an exclusive lock. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), @@ -203,7 +203,33 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { +void trivial_handler(int signum) {} + +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) { + const DisableSave ds; // Timing-related. + + // This test will verify that a shared lock is denied while + // someone holds an exclusive lock. + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), + SyscallSucceedsWithValue(0)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); + + // Register a signal handler for SIGALRM and set an alarm that will go off + // while blocking in the subsequent flock() call. This will interrupt flock() + // and cause it to return EINTR. + struct sigaction act = {}; + act.sa_handler = trivial_handler; + ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); + ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallFailsWithErrno(EINTR)); + + // Unlock + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); +} + +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderNonblocking) { // This test will verify that an exclusive lock is denied while // someone already holds an exclsuive lock. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), @@ -221,6 +247,30 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) { + const DisableSave ds; // Timing-related. + + // This test will verify that an exclusive lock is denied while + // someone already holds an exclsuive lock. + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), + SyscallSucceedsWithValue(0)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); + + // Register a signal handler for SIGALRM and set an alarm that will go off + // while blocking in the subsequent flock() call. This will interrupt flock() + // and cause it to return EINTR. + struct sigaction act = {}; + act.sa_handler = trivial_handler; + ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); + ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + ASSERT_THAT(flock(fd.get(), LOCK_EX), SyscallFailsWithErrno(EINTR)); + + // Unlock + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); +} + TEST_F(FlockTest, TestMultipleHolderSharedExclusiveUpgrade) { // This test will verify that we cannot obtain an exclusive lock while // a shared lock is held by another descriptor, then verify that an upgrade diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index 40c80a6e1..90b1f0508 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -18,6 +18,7 @@ #include <sys/syscall.h> #include <sys/time.h> #include <sys/types.h> +#include <syscall.h> #include <unistd.h> #include <algorithm> @@ -737,6 +738,97 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { } } +int get_robust_list(int pid, struct robust_list_head** head_ptr, + size_t* len_ptr) { + return syscall(__NR_get_robust_list, pid, head_ptr, len_ptr); +} + +int set_robust_list(struct robust_list_head* head, size_t len) { + return syscall(__NR_set_robust_list, head, len); +} + +TEST(RobustFutexTest, BasicSetGet) { + struct robust_list_head hd = {}; + struct robust_list_head* hd_ptr = &hd; + + // Set! + EXPECT_THAT(set_robust_list(hd_ptr, sizeof(hd)), SyscallSucceedsWithValue(0)); + + // Get! + struct robust_list_head* new_hd_ptr = hd_ptr; + size_t len; + EXPECT_THAT(get_robust_list(0, &new_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(new_hd_ptr, hd_ptr); + EXPECT_EQ(len, sizeof(hd)); +} + +TEST(RobustFutexTest, GetFromOtherTid) { + // Get the current tid and list head. + pid_t tid = gettid(); + struct robust_list_head* hd_ptr = {}; + size_t len; + EXPECT_THAT(get_robust_list(0, &hd_ptr, &len), SyscallSucceedsWithValue(0)); + + // Create a new thread. + ScopedThread t([&] { + // Current tid list head should be different from parent tid. + struct robust_list_head* got_hd_ptr = {}; + EXPECT_THAT(get_robust_list(0, &got_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_NE(hd_ptr, got_hd_ptr); + + // Get the parent list head by passing its tid. + EXPECT_THAT(get_robust_list(tid, &got_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(hd_ptr, got_hd_ptr); + }); + + // Wait for thread. + t.Join(); +} + +TEST(RobustFutexTest, InvalidSize) { + struct robust_list_head* hd = {}; + EXPECT_THAT(set_robust_list(hd, sizeof(*hd) + 1), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(RobustFutexTest, PthreadMutexAttr) { + constexpr int kNumMutexes = 3; + + // Create a bunch of robust mutexes. + pthread_mutexattr_t attrs[kNumMutexes]; + pthread_mutex_t mtxs[kNumMutexes]; + for (int i = 0; i < kNumMutexes; i++) { + TEST_PCHECK(pthread_mutexattr_init(&attrs[i]) == 0); + TEST_PCHECK(pthread_mutexattr_setrobust(&attrs[i], PTHREAD_MUTEX_ROBUST) == + 0); + TEST_PCHECK(pthread_mutex_init(&mtxs[i], &attrs[i]) == 0); + } + + // Start thread to lock the mutexes and then exit. + ScopedThread t([&] { + for (int i = 0; i < kNumMutexes; i++) { + TEST_PCHECK(pthread_mutex_lock(&mtxs[i]) == 0); + } + pthread_exit(NULL); + }); + + // Wait for thread. + t.Join(); + + // Now try to take the mutexes. + for (int i = 0; i < kNumMutexes; i++) { + // Should get EOWNERDEAD. + EXPECT_EQ(pthread_mutex_lock(&mtxs[i]), EOWNERDEAD); + // Make the mutex consistent. + EXPECT_EQ(pthread_mutex_consistent(&mtxs[i]), 0); + // Unlock. + EXPECT_EQ(pthread_mutex_unlock(&mtxs[i]), 0); + } +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index b147d6181..b040cdcf7 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -32,6 +32,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "test/util/eventfd_util.h" @@ -393,7 +394,7 @@ TYPED_TEST(GetdentsTest, ProcSelfFd) { // Make the buffer very small since we want to iterate. typename TestFixture::DirentBufferType dirents( 2 * sizeof(typename TestFixture::LinuxDirentType)); - std::unordered_set<int> prev_fds; + absl::node_hash_set<int> prev_fds; while (true) { dirents.Reset(); int rv; diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index 220874aeb..e4392a450 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -18,6 +18,7 @@ #include <sys/epoll.h> #include <sys/inotify.h> #include <sys/ioctl.h> +#include <sys/sendfile.h> #include <sys/time.h> #include <sys/xattr.h> @@ -464,7 +465,9 @@ TEST(Inotify, ConcurrentFileDeletionAndWatchRemoval) { for (int i = 0; i < 100; ++i) { FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT, S_IRUSR | S_IWUSR)); - file_fd.reset(); // Close before unlinking (although save is disabled). + // Close before unlinking (although S/R is disabled). Some filesystems + // cannot restore an open fd on an unlinked file. + file_fd.reset(); EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); } }; @@ -1255,10 +1258,7 @@ TEST(Inotify, MknodGeneratesCreateEvent) { InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); const TempPath file1(root.path() + "/file1"); - const int rc = mknod(file1.path().c_str(), S_IFREG, 0); - // mknod(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0); - ASSERT_THAT(rc, SyscallSucceeds()); + ASSERT_THAT(mknod(file1.path().c_str(), S_IFREG, 0), SyscallSucceeds()); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); @@ -1288,6 +1288,10 @@ TEST(Inotify, SymlinkGeneratesCreateEvent) { } TEST(Inotify, LinkGeneratesAttribAndCreateEvents) { + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); @@ -1300,11 +1304,8 @@ TEST(Inotify, LinkGeneratesAttribAndCreateEvents) { const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - const int rc = link(file1.path().c_str(), link1.path().c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); + ASSERT_THAT(link(file1.path().c_str(), link1.path().c_str()), + SyscallSucceeds()); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); @@ -1333,66 +1334,70 @@ TEST(Inotify, UtimesGeneratesAttribEvent) { } TEST(Inotify, HardlinksReuseSameWatch) { + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file1 = + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - TempPath link1(root.path() + "/link1"); - const int rc = link(file1.path().c_str(), link1.path().c_str()); - // link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); + + TempPath file2(root.path() + "/file2"); + ASSERT_THAT(link(file.path().c_str(), file2.path().c_str()), + SyscallSucceeds()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - const int link1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), link1.path(), IN_ALL_EVENTS)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file.path(), IN_ALL_EVENTS)); + const int file2_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file2.path(), IN_ALL_EVENTS)); // The watch descriptors for watches on different links to the same file // should be identical. - EXPECT_NE(root_wd, file1_wd); - EXPECT_EQ(file1_wd, link1_wd); + EXPECT_NE(root_wd, file_wd); + EXPECT_EQ(file_wd, file2_wd); - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); + FileDescriptor file_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); ASSERT_THAT(events, - AreUnordered({Event(IN_OPEN, root_wd, Basename(file1.path())), - Event(IN_OPEN, file1_wd)})); + AreUnordered({Event(IN_OPEN, root_wd, Basename(file.path())), + Event(IN_OPEN, file_wd)})); // For the next step, we want to ensure all fds to the file are closed. Do // that now and drain the resulting events. - file1_fd.reset(); + file_fd.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, - Are({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())), - Event(IN_CLOSE_WRITE, file1_wd)})); + ASSERT_THAT( + events, + AreUnordered({Event(IN_CLOSE_WRITE, root_wd, Basename(file.path())), + Event(IN_CLOSE_WRITE, file_wd)})); // Try removing the link and let's see what events show up. Note that after // this, we still have a link to the file so the watch shouldn't be // automatically removed. - const std::string link1_path = link1.reset(); + const std::string file2_path = file2.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, link1_wd), - Event(IN_DELETE, root_wd, Basename(link1_path))})); + ASSERT_THAT(events, + AreUnordered({Event(IN_ATTRIB, file2_wd), + Event(IN_DELETE, root_wd, Basename(file2_path))})); // Now remove the other link. Since this is the last link to the file, the // watch should be automatically removed. - const std::string file1_path = file1.reset(); + const std::string file_path = file.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); ASSERT_THAT( events, - AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd), - Event(IN_IGNORED, file1_wd), - Event(IN_DELETE, root_wd, Basename(file1_path))})); + AreUnordered({Event(IN_ATTRIB, file_wd), Event(IN_DELETE_SELF, file_wd), + Event(IN_IGNORED, file_wd), + Event(IN_DELETE, root_wd, Basename(file_path))})); } // Calling mkdir within "parent/child" should generate an event for child, but @@ -1681,6 +1686,60 @@ TEST(Inotify, EpollNoDeadlock) { } } +TEST(Inotify, Fallocate) { + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); + + // Do an arbitrary modification with fallocate. + ASSERT_THAT(RetryEINTR(fallocate)(fd.get(), 0, 0, 123), SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_MODIFY, wd)})); +} + +TEST(Inotify, Sendfile) { + SKIP_IF(IsRunningWithVFS1()); + + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(root.path(), "x", 0644)); + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + const FileDescriptor out = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); + + // Create separate inotify instances for the in and out fds. If both watches + // were on the same instance, we would have discrepancies between Linux and + // gVisor (order of events, duplicate events), which is not that important + // since inotify is asynchronous anyway. + const FileDescriptor in_inotify = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const FileDescriptor out_inotify = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int in_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(in_inotify.get(), in_file.path(), IN_ALL_EVENTS)); + const int out_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(out_inotify.get(), out_file.path(), IN_ALL_EVENTS)); + + ASSERT_THAT(sendfile(out.get(), in.get(), /*offset=*/nullptr, 1), + SyscallSucceeds()); + + // Expect a single access event and a single modify event. + std::vector<Event> in_events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(in_inotify.get())); + std::vector<Event> out_events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(out_inotify.get())); + EXPECT_THAT(in_events, Are({Event(IN_ACCESS, in_wd)})); + EXPECT_THAT(out_events, Are({Event(IN_MODIFY, out_wd)})); +} + // On Linux, inotify behavior is not very consistent with splice(2). We try our // best to emulate Linux for very basic calls to splice. TEST(Inotify, SpliceOnWatchTarget) { @@ -1749,17 +1808,17 @@ TEST(Inotify, SpliceOnInotifyFD) { // Watches on a parent should not be triggered by actions on a hard link to one // of its children that has a different parent. TEST(Inotify, LinkOnOtherParent) { + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); std::string link_path = NewTempAbsPathInDir(dir2.path()); - const int rc = link(file.path().c_str(), link_path.c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); + ASSERT_THAT(link(file.path().c_str(), link_path.c_str()), SyscallSucceeds()); const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); @@ -1768,13 +1827,18 @@ TEST(Inotify, LinkOnOtherParent) { // Perform various actions on the link outside of dir1, which should trigger // no inotify events. - const FileDescriptor fd = + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(link_path.c_str(), O_RDWR)); int val = 0; ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); + + // Close before unlinking; some filesystems cannot restore an open fd on an + // unlinked file. + fd.reset(); ASSERT_THAT(unlink(link_path.c_str()), SyscallSucceeds()); + const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); EXPECT_THAT(events, Are({})); @@ -1879,14 +1943,22 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ - Event(IN_ATTRIB, file_wd), - Event(IN_DELETE, dir_wd, Basename(file.path())), - Event(IN_ACCESS, dir_wd, Basename(file.path())), - Event(IN_ACCESS, file_wd), - Event(IN_MODIFY, dir_wd, Basename(file.path())), - Event(IN_MODIFY, file_wd), - })); + EXPECT_THAT(events, AnyOf(Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }), + Are({ + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }))); fd.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); @@ -1929,7 +2001,7 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) { ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ + EXPECT_THAT(events, AreUnordered({ Event(IN_ATTRIB, file_wd), Event(IN_DELETE, dir_wd, Basename(file.path())), })); @@ -1990,21 +2062,21 @@ TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { - const DisableSave ds; + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); // TODO(gvisor.dev/issue/1624): This test fails on VFS1. SKIP_IF(IsRunningWithVFS1()); + const DisableSave ds; + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); std::string path1 = file.path(); std::string path2 = NewTempAbsPathInDir(dir.path()); + ASSERT_THAT(link(path1.c_str(), path2.c_str()), SyscallSucceeds()); - const int rc = link(path1.c_str(), path2.c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); const FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(Open(path1.c_str(), O_RDWR)); const FileDescriptor fd2 = @@ -2036,6 +2108,15 @@ TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { + // TODO(gvisor.dev/issue/1624): Fails on VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // NOTE(gvisor.dev/issue/3654): In the gofer filesystem, we do not allow + // setting attributes through an fd if the file at the open path has been + // deleted. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const DisableSave ds; const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -2045,18 +2126,6 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR)); - // NOTE(b/157163751): Create another link before unlinking. This is needed for - // the gofer filesystem in gVisor, where open fds will not work once the link - // count hits zero. In VFS2, we end up skipping the gofer test anyway, because - // hard links are not supported for gofer fs. - if (IsRunningOnGvisor()) { - std::string link_path = NewTempAbsPath(); - const int rc = link(file.path().c_str(), link_path.c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(rc != 0 && (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); - } - const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( @@ -2072,12 +2141,18 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ - Event(IN_ATTRIB, file_wd), - Event(IN_DELETE, dir_wd, Basename(file.path())), - Event(IN_MODIFY, dir_wd, Basename(file.path())), - Event(IN_MODIFY, file_wd), - })); + EXPECT_THAT(events, AnyOf(Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }), + Are({ + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }))); const struct timeval times[2] = {{1, 0}, {2, 0}}; ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds()); diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc new file mode 100644 index 000000000..de0a1c114 --- /dev/null +++ b/test/syscalls/linux/ip6tables.cc @@ -0,0 +1,233 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/capability.h> +#include <sys/socket.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/iptables.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kNatTablename[] = "nat"; +constexpr char kErrorTarget[] = "ERROR"; +constexpr size_t kEmptyStandardEntrySize = + sizeof(struct ip6t_entry) + sizeof(struct xt_standard_target); +constexpr size_t kEmptyErrorEntrySize = + sizeof(struct ip6t_entry) + sizeof(struct xt_error_target); + +TEST(IP6TablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // ip6tables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + EXPECT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IP6TablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IP6TablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ip6t_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ip6t_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IP6TablesBasic, GetRevision) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + struct xt_get_revision rev = { + .name = "REDIRECT", + .revision = 0, + }; + socklen_t rev_len = sizeof(rev); + + // Revision 0 exists. + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallSucceeds()); + EXPECT_EQ(rev.revision, 0); + + // Revisions > 0 don't exist. + rev.revision = 1; + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallFailsWithErrno(EPROTONOSUPPORT)); +} + +// This tests the initial state of a machine with empty ip6tables via +// getsockopt(IP6T_SO_GET_INFO). We don't have a guarantee that the iptables are +// empty when running in native, but we can test that gVisor has the same +// initial state that a newly-booted Linux machine would have. +TEST(IP6TablesTest, InitialInfo) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)); + + // Get info via sockopt. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT( + getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // The nat table supports PREROUTING, and OUTPUT. + unsigned int valid_hooks = + (1 << NF_IP6_PRE_ROUTING) | (1 << NF_IP6_LOCAL_OUT) | + (1 << NF_IP6_POST_ROUTING) | (1 << NF_IP6_LOCAL_IN); + EXPECT_EQ(info.valid_hooks, valid_hooks); + + // Each chain consists of an empty entry with a standard target.. + EXPECT_EQ(info.hook_entry[NF_IP6_PRE_ROUTING], 0); + EXPECT_EQ(info.hook_entry[NF_IP6_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.hook_entry[NF_IP6_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.hook_entry[NF_IP6_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // The underflow points are the same as the entry points. + EXPECT_EQ(info.underflow[NF_IP6_PRE_ROUTING], 0); + EXPECT_EQ(info.underflow[NF_IP6_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.underflow[NF_IP6_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.underflow[NF_IP6_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // One entry for each chain, plus an error entry at the end. + EXPECT_EQ(info.num_entries, 5); + + EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); + EXPECT_EQ(strcmp(info.name, kNatTablename), 0); +} + +// This tests the initial state of a machine with empty ip6tables via +// getsockopt(IP6T_SO_GET_ENTRIES). We don't have a guarantee that the iptables +// are empty when running in native, but we can test that gVisor has the same +// initial state that a newly-booted Linux machine would have. +TEST(IP6TablesTest, InitialEntries) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)); + + // Get info via sockopt. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT( + getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // Use info to get entries. + socklen_t entries_size = sizeof(struct ip6t_get_entries) + info.size; + struct ip6t_get_entries* entries = + static_cast<struct ip6t_get_entries*>(malloc(entries_size)); + snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + entries->size = info.size; + ASSERT_THAT(getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_ENTRIES, entries, + &entries_size), + SyscallSucceeds()); + + // Verify the name and size. + ASSERT_EQ(info.size, entries->size); + ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); + + // Verify that the entrytable is 4 entries with accept targets and no matches + // followed by a single error target. + size_t entry_offset = 0; + while (entry_offset < entries->size) { + struct ip6t_entry* entry = reinterpret_cast<struct ip6t_entry*>( + reinterpret_cast<char*>(entries->entrytable) + entry_offset); + + // ipv6 should be zeroed. + struct ip6t_ip6 zeroed = {}; + ASSERT_EQ(memcmp(static_cast<void*>(&zeroed), + static_cast<void*>(&entry->ipv6), sizeof(zeroed)), + 0); + + // target_offset should be zero. + EXPECT_EQ(entry->target_offset, sizeof(ip6t_entry)); + + if (entry_offset < kEmptyStandardEntrySize * 4) { + // The first 4 entries are standard targets + struct xt_standard_target* target = + reinterpret_cast<struct xt_standard_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + // This is what's returned for an accept verdict. I don't know why. + EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); + } else { + // The last entry is an error target + struct xt_error_target* target = + reinterpret_cast<struct xt_error_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); + } + + entry_offset += entry->next_offset; + break; + } + + free(entries); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc index b8e4ece64..7ee10bbde 100644 --- a/test/syscalls/linux/iptables.cc +++ b/test/syscalls/linux/iptables.cc @@ -67,12 +67,82 @@ TEST(IPTablesBasic, FailSockoptNonRaw) { struct ipt_getinfo info = {}; snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); socklen_t info_size = sizeof(info); - EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + EXPECT_THAT(getsockopt(sock, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), SyscallFailsWithErrno(ENOPROTOOPT)); ASSERT_THAT(close(sock), SyscallSucceeds()); } +TEST(IPTablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + ASSERT_THAT(getsockopt(sock, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IPTablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ipt_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + ASSERT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IPTablesBasic, OriginalDstErrors) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_STREAM, 0), SyscallSucceeds()); + + // Sockets not affected by NAT should fail to find an original destination. + struct sockaddr_in addr = {}; + socklen_t addr_len = sizeof(addr); + EXPECT_THAT(getsockopt(sock, SOL_IP, SO_ORIGINAL_DST, &addr, &addr_len), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST(IPTablesBasic, GetRevision) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallSucceeds()); + + struct xt_get_revision rev = { + .name = "REDIRECT", + .revision = 0, + }; + socklen_t rev_len = sizeof(rev); + + // Revision 0 exists. + EXPECT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallSucceeds()); + EXPECT_EQ(rev.revision, 0); + + // Revisions > 0 don't exist. + rev.revision = 1; + EXPECT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallFailsWithErrno(EPROTONOSUPPORT)); +} + // Fixture for iptables tests. class IPTablesTest : public ::testing::Test { protected: @@ -112,7 +182,7 @@ TEST_F(IPTablesTest, InitialState) { struct ipt_getinfo info = {}; snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); socklen_t info_size = sizeof(info); - ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + ASSERT_THAT(getsockopt(s_, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), SyscallSucceeds()); // The nat table supports PREROUTING, and OUTPUT. @@ -148,7 +218,7 @@ TEST_F(IPTablesTest, InitialState) { snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); entries->size = info.size; ASSERT_THAT( - getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), + getsockopt(s_, SOL_IP, IPT_SO_GET_ENTRIES, entries, &entries_size), SyscallSucceeds()); // Verify the name and size. diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h index 0719c60a4..d0fc10fea 100644 --- a/test/syscalls/linux/iptables.h +++ b/test/syscalls/linux/iptables.h @@ -27,27 +27,32 @@ #include <linux/netfilter/x_tables.h> #include <linux/netfilter_ipv4.h> +#include <linux/netfilter_ipv6.h> #include <net/if.h> #include <netinet/ip.h> #include <stdint.h> +// +// IPv4 ABI. +// + #define ipt_standard_target xt_standard_target #define ipt_entry_target xt_entry_target #define ipt_error_target xt_error_target enum SockOpts { // For setsockopt. - BASE_CTL = 64, - SO_SET_REPLACE = BASE_CTL, - SO_SET_ADD_COUNTERS, - SO_SET_MAX = SO_SET_ADD_COUNTERS, + IPT_BASE_CTL = 64, + IPT_SO_SET_REPLACE = IPT_BASE_CTL, + IPT_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1, + IPT_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS, // For getsockopt. - SO_GET_INFO = BASE_CTL, - SO_GET_ENTRIES, - SO_GET_REVISION_MATCH, - SO_GET_REVISION_TARGET, - SO_GET_MAX = SO_GET_REVISION_TARGET + IPT_SO_GET_INFO = IPT_BASE_CTL, + IPT_SO_GET_ENTRIES = IPT_BASE_CTL + 1, + IPT_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 2, + IPT_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 3, + IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET }; // ipt_ip specifies basic matching criteria that can be applied by examining @@ -115,7 +120,7 @@ struct ipt_entry { unsigned char elems[0]; }; -// Passed to getsockopt(SO_GET_INFO). +// Passed to getsockopt(IPT_SO_GET_INFO). struct ipt_getinfo { // The name of the table. The user only fills this in, the rest is filled in // when returning from getsockopt. Currently "nat" and "mangle" are supported. @@ -127,7 +132,7 @@ struct ipt_getinfo { unsigned int valid_hooks; // The offset into the entry table for each valid hook. The entry table is - // returned by getsockopt(SO_GET_ENTRIES). + // returned by getsockopt(IPT_SO_GET_ENTRIES). unsigned int hook_entry[NF_IP_NUMHOOKS]; // For each valid hook, the underflow is the offset into the entry table to @@ -142,14 +147,14 @@ struct ipt_getinfo { unsigned int underflow[NF_IP_NUMHOOKS]; // The number of entries in the entry table returned by - // getsockopt(SO_GET_ENTRIES). + // getsockopt(IPT_SO_GET_ENTRIES). unsigned int num_entries; - // The size of the entry table returned by getsockopt(SO_GET_ENTRIES). + // The size of the entry table returned by getsockopt(IPT_SO_GET_ENTRIES). unsigned int size; }; -// Passed to getsockopt(SO_GET_ENTRIES). +// Passed to getsockopt(IPT_SO_GET_ENTRIES). struct ipt_get_entries { // The name of the table. The user fills this in. Currently "nat" and "mangle" // are supported. @@ -195,4 +200,103 @@ struct ipt_replace { struct ipt_entry entries[0]; }; +// +// IPv6 ABI. +// + +enum SockOpts6 { + // For setsockopt. + IP6T_BASE_CTL = 64, + IP6T_SO_SET_REPLACE = IP6T_BASE_CTL, + IP6T_SO_SET_ADD_COUNTERS = IP6T_BASE_CTL + 1, + IP6T_SO_SET_MAX = IP6T_SO_SET_ADD_COUNTERS, + + // For getsockopt. + IP6T_SO_GET_INFO = IP6T_BASE_CTL, + IP6T_SO_GET_ENTRIES = IP6T_BASE_CTL + 1, + IP6T_SO_GET_REVISION_MATCH = IP6T_BASE_CTL + 4, + IP6T_SO_GET_REVISION_TARGET = IP6T_BASE_CTL + 5, + IP6T_SO_GET_MAX = IP6T_SO_GET_REVISION_TARGET +}; + +// ip6t_ip6 specifies basic matching criteria that can be applied by examining +// only the IP header of a packet. +struct ip6t_ip6 { + // Source IP address. + struct in6_addr src; + + // Destination IP address. + struct in6_addr dst; + + // Source IP address mask. + struct in6_addr smsk; + + // Destination IP address mask. + struct in6_addr dmsk; + + // Input interface. + char iniface[IFNAMSIZ]; + + // Output interface. + char outiface[IFNAMSIZ]; + + // Input interface mask. + unsigned char iniface_mask[IFNAMSIZ]; + + // Output interface mask. + unsigned char outiface_mask[IFNAMSIZ]; + + // Transport protocol. + uint16_t proto; + + // TOS. + uint8_t tos; + + // Flags. + uint8_t flags; + + // Inverse flags. + uint8_t invflags; +}; + +// ip6t_entry is an ip6tables rule. +struct ip6t_entry { + // Basic matching information used to match a packet's IP header. + struct ip6t_ip6 ipv6; + + // A caching field that isn't used by userspace. + unsigned int nfcache; + + // The number of bytes between the start of this entry and the rule's target. + uint16_t target_offset; + + // The total size of this rule, from the beginning of the entry to the end of + // the target. + uint16_t next_offset; + + // A return pointer not used by userspace. + unsigned int comefrom; + + // Counters for packets and bytes, which we don't yet implement. + struct xt_counters counters; + + // The data for all this rules matches followed by the target. This runs + // beyond the value of sizeof(struct ip6t_entry). + unsigned char elems[0]; +}; + +// Passed to getsockopt(IP6T_SO_GET_ENTRIES). +struct ip6t_get_entries { + // The name of the table. + char name[XT_TABLE_MAXNAMELEN]; + + // The size of the entry table in bytes. The user fills this in with the value + // from struct ipt_getinfo.size. + unsigned int size; + + // The entries for the given table. This will run past the size defined by + // sizeof(struct ip6t_get_entries). + struct ip6t_entry entrytable[0]; +}; + #endif // GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_ diff --git a/test/syscalls/linux/kcov.cc b/test/syscalls/linux/kcov.cc new file mode 100644 index 000000000..6816c1fd0 --- /dev/null +++ b/test/syscalls/linux/kcov.cc @@ -0,0 +1,184 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/errno.h> +#include <sys/ioctl.h> +#include <sys/mman.h> + +#include <atomic> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// For this set of tests to run, they must be run with coverage enabled. On +// native Linux, this involves compiling the kernel with kcov enabled. For +// gVisor, we need to enable the Go coverage tool, e.g. bazel test -- +// collect_coverage_data --instrumentation_filter=//pkg/... <test>. + +constexpr char kcovPath[] = "/sys/kernel/debug/kcov"; +constexpr int kSize = 4096; +constexpr int KCOV_INIT_TRACE = 0x80086301; +constexpr int KCOV_ENABLE = 0x6364; +constexpr int KCOV_DISABLE = 0x6365; + +uint64_t* KcovMmap(int fd) { + return (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); +} + +TEST(KcovTest, Kcov) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + + for (int i = 0; i < 10; i++) { + // Make some syscalls to generate coverage data. + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallFailsWithErrno(EINVAL)); + } + + uint64_t num_pcs = *(uint64_t*)(area); + EXPECT_GT(num_pcs, 0); + for (uint64_t i = 1; i <= num_pcs; i++) { + // Verify that PCs are in the standard kernel range. + EXPECT_GT(area[i], 0xffffffff7fffffffL); + } + + ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); +} + +TEST(KcovTest, PrematureMmap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + // Cannot mmap before KCOV_INIT_TRACE. + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area == MAP_FAILED); +} + +// Tests that multiple kcov fds can be used simultaneously. +TEST(KcovTest, MultipleFds) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd1; + ASSERT_THAT(fd1 = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + + int fd2; + ASSERT_THAT(fd2 = open(kcovPath, O_RDWR), SyscallSucceeds()); + auto fd_closer = Cleanup([fd1, fd2]() { + close(fd1); + close(fd2); + }); + + auto t1 = ScopedThread([&] { + ASSERT_THAT(ioctl(fd1, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd1); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd1, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + ASSERT_THAT(ioctl(fd2, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd2); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd2, KCOV_ENABLE, 0), SyscallSucceeds()); +} + +// Tests behavior for two threads trying to use the same kcov fd. +TEST(KcovTest, MultipleThreads) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + // Test the behavior of multiple threads trying to use the same kcov fd + // simultaneously. + std::atomic<bool> t1_enabled(false), t1_disabled(false), t2_failed(false), + t2_exited(false); + auto t1 = ScopedThread([&] { + ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + t1_enabled = true; + + // After t2 has made sure that enabling kcov again fails, disable it. + while (!t2_failed) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); + t1_disabled = true; + + // Wait for t2 to enable kcov and then exit, after which we should be able + // to enable kcov again, without needing to set up a new memory mapping. + while (!t2_exited) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + auto t2 = ScopedThread([&] { + // Wait for t1 to enable kcov, and make sure that enabling kcov again fails. + while (!t1_enabled) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallFailsWithErrno(EINVAL)); + t2_failed = true; + + // Wait for t1 to disable kcov, after which using fd should now succeed. + while (!t1_disabled) { + sched_yield(); + } + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + t2.Join(); + t2_exited = true; +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index f8b7f7938..4a450742b 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -14,12 +14,10 @@ #include <errno.h> #include <fcntl.h> -#include <linux/magic.h> #include <linux/memfd.h> #include <linux/unistd.h> #include <string.h> #include <sys/mman.h> -#include <sys/statfs.h> #include <sys/syscall.h> #include <vector> @@ -53,6 +51,7 @@ namespace { #define F_SEAL_GROW 0x0004 #define F_SEAL_WRITE 0x0008 +using ::gvisor::testing::IsTmpfs; using ::testing::StartsWith; const std::string kMemfdName = "some-memfd"; @@ -444,20 +443,6 @@ TEST(MemfdTest, SealsAreInodeLevelProperties) { EXPECT_THAT(ftruncate(memfd3.get(), kPageSize), SyscallFailsWithErrno(EPERM)); } -PosixErrorOr<bool> IsTmpfs(const std::string& path) { - struct statfs stat; - if (statfs(path.c_str(), &stat)) { - if (errno == ENOENT) { - // Nothing at path, don't raise this as an error. Instead, just report no - // tmpfs at path. - return false; - } - return PosixError(errno, - absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); - } - return stat.f_type == TMPFS_MAGIC; -} - // Tmpfs files also support seals, but are created with F_SEAL_SEAL. TEST(MemfdTest, TmpfsFilesHaveSealSeal) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp"))); diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index 4036a9275..27758203d 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -82,6 +82,13 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { SyscallFailsWithErrno(EACCES)); } +TEST_F(MkdirTest, MkdirAtEmptyPath) { + ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirname_, O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(mkdirat(fd.get(), "", 0777), SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index 4c45766c7..ae65d366b 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -14,7 +14,9 @@ #include <errno.h> #include <fcntl.h> +#include <sys/socket.h> #include <sys/stat.h> +#include <sys/types.h> #include <sys/un.h> #include <unistd.h> @@ -39,7 +41,28 @@ TEST(MknodTest, RegularFile) { EXPECT_THAT(mknod(node1.c_str(), 0, 0), SyscallSucceeds()); } -TEST(MknodTest, MknodAtRegularFile) { +TEST(MknodTest, RegularFilePermissions) { + const std::string node = NewTempAbsPath(); + mode_t newUmask = 0077; + umask(newUmask); + + // Attempt to open file with mode 0777. Not specifying file type should create + // a regualar file. + mode_t perms = S_IRWXU | S_IRWXG | S_IRWXO; + EXPECT_THAT(mknod(node.c_str(), perms, 0), SyscallSucceeds()); + + // In the absence of a default ACL, the permissions of the created node are + // (mode & ~umask). -- mknod(2) + mode_t wantPerms = perms & ~newUmask; + struct stat st; + ASSERT_THAT(stat(node.c_str(), &st), SyscallSucceeds()); + ASSERT_EQ(st.st_mode & 0777, wantPerms); + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + ASSERT_EQ(st.st_mode & S_IFMT, S_IFREG); +} + +TEST(MknodTest, MknodAtFIFO) { const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string fifo_relpath = NewTempRelPath(); const std::string fifo = JoinPath(dir.path(), fifo_relpath); @@ -72,7 +95,7 @@ TEST(MknodTest, MknodOnExistingPathFails) { TEST(MknodTest, UnimplementedTypesReturnError) { const std::string path = NewTempAbsPath(); - if (IsRunningOnGvisor()) { + if (IsRunningWithVFS1()) { ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0), SyscallFailsWithErrno(EOPNOTSUPP)); } @@ -81,6 +104,27 @@ TEST(MknodTest, UnimplementedTypesReturnError) { ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM)); } +TEST(MknodTest, Socket) { + SKIP_IF(IsRunningOnGvisor() && IsRunningWithVFS1()); + + ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); + + auto filename = NewTempRelPath(); + + ASSERT_THAT(mknod(filename.c_str(), S_IFSOCK | S_IRUSR | S_IWUSR, 0), + SyscallSucceeds()); + + int sk; + ASSERT_THAT(sk = socket(AF_UNIX, SOCK_SEQPACKET, 0), SyscallSucceeds()); + FileDescriptor fd(sk); + + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + absl::SNPrintF(addr.sun_path, sizeof(addr.sun_path), "%s", filename.c_str()); + ASSERT_THAT(connect(sk, (struct sockaddr *)&addr, sizeof(addr)), + SyscallFailsWithErrno(ECONNREFUSED)); + ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds()); +} + TEST(MknodTest, Fifo) { const std::string fifo = NewTempAbsPath(); ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), @@ -162,6 +206,14 @@ TEST(MknodTest, FifoTruncNoOp) { EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(MknodTest, MknodAtEmptyPath) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(mknodat(fd.get(), "", S_IFREG | 0777, 0), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 6d3227ab6..e52c9cbcb 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -43,6 +43,8 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +using ::testing::AnyOf; +using ::testing::Eq; using ::testing::Gt; namespace gvisor { @@ -296,7 +298,8 @@ TEST_F(MMapTest, MapDevZeroSegfaultAfterUnmap) { }; EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); + IsPosixErrorOkAndHolds(AnyOf(Eq(W_EXITCODE(0, SIGSEGV)), + Eq(W_EXITCODE(0, 128 + SIGSEGV))))); } TEST_F(MMapTest, MapDevZeroUnaligned) { diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index a3e9745cf..3aab25b23 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -147,8 +147,15 @@ TEST(MountTest, UmountDetach) { // Unmount the tmpfs. mount.Release()(); - const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after2.st_ino); + // Only check for inode number equality if the directory is not in overlayfs. + // If xino option is not enabled and if all overlayfs layers do not belong to + // the same filesystem then "the value of st_ino for directory objects may not + // be persistent and could change even while the overlay filesystem is + // mounted." -- Documentation/filesystems/overlayfs.txt + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); + EXPECT_EQ(before.st_ino, after2.st_ino); + } // Can still read file after unmounting. std::vector<char> buf(sizeof(kContents)); @@ -213,8 +220,15 @@ TEST(MountTest, MountTmpfs) { } // Now that dir is unmounted again, we should have the old inode back. - const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after.st_ino); + // Only check for inode number equality if the directory is not in overlayfs. + // If xino option is not enabled and if all overlayfs layers do not belong to + // the same filesystem then "the value of st_ino for directory objects may not + // be persistent and could change even while the overlay filesystem is + // mounted." -- Documentation/filesystems/overlayfs.txt + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); + EXPECT_EQ(before.st_ino, after.st_ino); + } } TEST(MountTest, MountTmpfsMagicValIgnored) { @@ -321,6 +335,42 @@ TEST(MountTest, RenameRemoveMountPoint) { ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } +TEST(MountTest, MountFuseFilesystemNoDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + // Before kernel version 4.16-rc6, FUSE mount is protected by + // capable(CAP_SYS_ADMIN). After this version, it uses + // ns_capable(CAP_SYS_ADMIN) to protect. Before the 4.16 kernel, it was not + // allowed to mount fuse file systems without the global CAP_SYS_ADMIN. + int res = mount("", dir.path().c_str(), "fuse", 0, ""); + SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); + + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(MountTest, MountFuseFilesystem) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); + std::string mopts = "fd=" + std::to_string(fd.get()); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + // See comments in MountFuseFilesystemNoDevice for the reason why we skip + // EPERM when running on Linux. + int res = mount("", dir.path().c_str(), "fuse", 0, ""); + SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); + + auto const mount = + ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index bb7d108e8..77f390f3c 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -27,6 +27,7 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -95,6 +96,38 @@ TEST_F(OpenTest, OTruncAndReadOnlyFile) { Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); } +TEST_F(OpenTest, OCreateDirectory) { + SKIP_IF(IsRunningWithVFS1()); + auto dirpath = GetAbsoluteTestTmpdir(); + + // Normal case: existing directory. + ASSERT_THAT(open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // Trailing separator on existing directory. + ASSERT_THAT(open(dirpath.append("/").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // Trailing separator on non-existing directory. + ASSERT_THAT(open(JoinPath(dirpath, "non-existent").append("/").c_str(), + O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // "." special case. + ASSERT_THAT(open(JoinPath(dirpath, ".").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, MustCreateExisting) { + auto dirPath = GetAbsoluteTestTmpdir(); + + // Existing directory. + ASSERT_THAT(open(dirPath.c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + SyscallFailsWithErrno(EEXIST)); + + // Existing file. + auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dirPath)); + ASSERT_THAT(open(newFile.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + SyscallFailsWithErrno(EEXIST)); +} + TEST_F(OpenTest, ReadOnly) { char buf; const FileDescriptor ro_file = @@ -115,6 +148,26 @@ TEST_F(OpenTest, WriteOnly) { EXPECT_THAT(write(wo_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); } +TEST_F(OpenTest, CreateWithAppend) { + std::string data = "text"; + std::string new_file = NewTempAbsPath(); + const FileDescriptor file = ASSERT_NO_ERRNO_AND_VALUE( + Open(new_file, O_WRONLY | O_APPEND | O_CREAT, 0666)); + EXPECT_THAT(write(file.get(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(file.get(), 0, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT(write(file.get(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // Check that the size of the file is correct and that the offset has been + // incremented to that size. + struct stat s0; + EXPECT_THAT(fstat(file.get(), &s0), SyscallSucceeds()); + EXPECT_EQ(s0.st_size, 2 * data.size()); + EXPECT_THAT(lseek(file.get(), 0, SEEK_CUR), + SyscallSucceedsWithValue(2 * data.size())); +} + TEST_F(OpenTest, ReadWrite) { char buf; const FileDescriptor rw_file = @@ -235,7 +288,7 @@ TEST_F(OpenTest, AppendOnly) { ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND)); EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - // Then try to write to the first file and make sure the bytes are appended. + // Then try to write to the first fd and make sure the bytes are appended. EXPECT_THAT(WriteFd(fd1.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -247,7 +300,7 @@ TEST_F(OpenTest, AppendOnly) { EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(kBufSize * 2)); - // Then try to write to the second file and make sure the bytes are appended. + // Then try to write to the second fd and make sure the bytes are appended. EXPECT_THAT(WriteFd(fd2.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -356,6 +409,13 @@ TEST_F(OpenTest, FileNotDirectory) { SyscallFailsWithErrno(ENOTDIR)); } +TEST_F(OpenTest, SymlinkDirectory) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string link = NewTempAbsPath(); + ASSERT_THAT(symlink(dir.path().c_str(), link.c_str()), SyscallSucceeds()); + ASSERT_NO_ERRNO(Open(link, O_RDONLY | O_DIRECTORY)); +} + TEST_F(OpenTest, Null) { char c = '\0'; ASSERT_THAT(open(&c, O_RDONLY), SyscallFailsWithErrno(ENOENT)); diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 51eacf3f2..78c36f98f 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -88,21 +88,21 @@ TEST(CreateTest, CreateExclusively) { SyscallFailsWithErrno(EEXIST)); } -TEST(CreateTeast, CreatWithOTrunc) { +TEST(CreateTest, CreatWithOTrunc) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } -TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { +TEST(CreateTest, CreatDirWithOTruncAndReadOnly) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } -TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { +TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); int dirfd; ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), @@ -149,6 +149,116 @@ TEST(CreateTest, OpenCreateROThenRW) { EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1)); } +TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) { + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be + // cleared for the same reason. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); + + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + // Cannot restore after making permissions more restrictive. + const DisableSave ds; + ASSERT_THAT(fchmod(rfd.get(), 0200), SyscallSucceeds()); + + EXPECT_THAT(open(file.path().c_str(), O_RDONLY), + SyscallFailsWithErrno(EACCES)); + + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) { + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0200)); + + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + + // Cannot restore after making permissions more restrictive. + const DisableSave ds; + ASSERT_THAT(fchmod(wfd.get(), 0400), SyscallSucceeds()); + + EXPECT_THAT(open(file.path().c_str(), O_WRONLY), + SyscallFailsWithErrno(EACCES)); + + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) { + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be + // cleared for the same reason. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + // Create and open a file with read flag but without read permissions. + const std::string path = NewTempAbsPath(); + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_RDONLY, 0222)); + + EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + const FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_WRONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) { + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + // Create and open a file with write flag but without write permissions. + const std::string path = NewTempAbsPath(); + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_WRONLY, 0444)); + + EXPECT_THAT(open(path.c_str(), O_WRONLY), SyscallFailsWithErrno(EACCES)); + const FileDescriptor rfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 5ac68feb4..861617ff7 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -188,11 +188,12 @@ void ReceiveMessage(int sock, int ifindex) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); @@ -233,7 +234,7 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -343,7 +344,7 @@ TEST_P(CookedPacketTest, BindReceive) { } // Double Bind socket. -TEST_P(CookedPacketTest, DoubleBind) { +TEST_P(CookedPacketTest, DoubleBindSucceeds) { struct sockaddr_ll bind_addr = {}; bind_addr.sll_family = AF_PACKET; bind_addr.sll_protocol = htons(GetParam()); @@ -354,12 +355,11 @@ TEST_P(CookedPacketTest, DoubleBind) { SyscallSucceeds()); // Binding socket again should fail. - ASSERT_THAT( - bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched - // to EADDRINUSE. - AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it + // switched to EADDRINUSE. + SyscallSucceeds()); } // Bind and verify we do not receive data on interface which is not bound @@ -417,6 +417,122 @@ TEST_P(CookedPacketTest, BindDrop) { EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); } +// Verify that we receive outbound packets. This test requires at least one +// non loopback interface so that we can actually capture an outgoing packet. +TEST_P(CookedPacketTest, ReceiveOutbound) { + // Only ETH_P_ALL sockets can receive outbound packets on linux. + SKIP_IF(GetParam() != ETH_P_ALL); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index and name. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + int ifindex = ifr.ifr_ifindex; + + constexpr int kMACSize = 6; + char hwaddr[kMACSize]; + // Get interface address. + ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + ASSERT_THAT(ifr.ifr_hwaddr.sa_family, + AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER))); + memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize); + + // Just send it to the google dns server 8.8.8.8. It's UDP we don't care + // if it actually gets to the DNS Server we just want to see that we receive + // it on our AF_PACKET socket. + // + // NOTE: We just want to pick an IP that is non-local to avoid having to + // handle ARP as this should cause the UDP packet to be sent to the default + // gateway configured for the system under test. Otherwise the only packet we + // will see is the ARP query unless we picked an IP which will actually + // resolve. The test is a bit brittle but this was the best compromise for + // now. + struct sockaddr_in dest = {}; + ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket receives the data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1)); + + // Now read and check that the packet is the one we just sent. + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, ifindex); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING); + // Verify the link address of the interface matches that of the non + // non loopback interface address we stored above. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], hwaddr[i]); + } + + // Verify the IP header. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr); + EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + // Bind with invalid address. TEST_P(CookedPacketTest, BindFail) { // Null address. diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 4093ac813..b558e3a01 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -14,6 +14,7 @@ #include <arpa/inet.h> #include <linux/capability.h> +#include <linux/filter.h> #include <linux/if_arp.h> #include <linux/if_packet.h> #include <net/ethernet.h> @@ -192,11 +193,12 @@ TEST_P(RawPacketTest, Receive) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); @@ -236,7 +238,7 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -556,6 +558,112 @@ TEST_P(RawPacketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } +TEST_P(RawPacketTest, GetSocketError) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); +} + +TEST_P(RawPacketTest, GetSocketErrorBind) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + { + // Bind to the loopback device. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // SO_ERROR should return no errors. + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); + } + + { + // Now try binding to an invalid interface. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = 0xffff; // Just pick a really large number. + + // Binding should fail with EINVAL + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallFailsWithErrno(ENODEV)); + + // SO_ERROR does not return error when the device is invalid. + // On Linux there is just one odd ball condition where this can return + // an error where the device was valid and then removed or disabled + // between the first check for index and the actual registration of + // the packet endpoint. On Netstack this is not possible as the stack + // global mutex is held during registration and check. + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); + } +} + +TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + // + // gVisor returns no error on SO_DETACH_FILTER even if there is no filter + // attached unlike linux which does return ENOENT in such cases. This is + // because gVisor doesn't support SO_ATTACH_FILTER and just silently returns + // success. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawPacketTest, GetSocketDetachFilter) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +TEST_P(RawPacketTest, SetAndGetSocketLinger) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 34291850d..c097c9187 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -13,7 +13,9 @@ // limitations under the License. #include <fcntl.h> /* Obtain O_* constant definitions */ +#include <linux/magic.h> #include <sys/ioctl.h> +#include <sys/statfs.h> #include <sys/uio.h> #include <unistd.h> @@ -198,6 +200,16 @@ TEST_P(PipeTest, NonBlocking) { SyscallFailsWithErrno(EWOULDBLOCK)); } +TEST(PipeTest, StatFS) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + struct statfs st; + EXPECT_THAT(fstatfs(fds[0], &st), SyscallSucceeds()); + EXPECT_EQ(st.f_type, PIPEFS_MAGIC); + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + TEST(Pipe2Test, CloExec) { int fds[2]; ASSERT_THAT(pipe2(fds, O_CLOEXEC), SyscallSucceeds()); diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index 04c5161f5..f675dc430 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -153,7 +153,7 @@ TEST(PrctlTest, PDeathSig) { // Enable tracing, then raise SIGSTOP and expect our parent to suppress // it. TEST_CHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) >= 0); - raise(SIGSTOP); + TEST_CHECK(raise(SIGSTOP) == 0); // Sleep until killed by our parent death signal. sleep(3) is // async-signal-safe, absl::SleepFor isn't. while (true) { diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index d6b875dbf..e8fcc4439 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -16,6 +16,7 @@ #include <errno.h> #include <fcntl.h> #include <limits.h> +#include <linux/magic.h> #include <sched.h> #include <signal.h> #include <stddef.h> @@ -26,6 +27,7 @@ #include <sys/mman.h> #include <sys/prctl.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/utsname.h> #include <syscall.h> #include <unistd.h> @@ -45,6 +47,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" @@ -61,6 +64,7 @@ #include "test/util/fs_util.h" #include "test/util/memory_util.h" #include "test/util/posix_error.h" +#include "test/util/proc_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -670,6 +674,23 @@ TEST(ProcSelfMaps, Mprotect) { 3 * kPageSize, PROT_READ))); } +TEST(ProcSelfMaps, SharedAnon) { + const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ, MAP_SHARED | MAP_ANONYMOUS)); + + const auto proc_self_maps = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); + for (const auto& line : absl::StrSplit(proc_self_maps, '\n')) { + const auto entry = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMapsLine(line)); + if (entry.start <= m.addr() && m.addr() < entry.end) { + // cf. proc(5), "/proc/[pid]/map_files/" + EXPECT_EQ(entry.filename, "/dev/zero (deleted)"); + return; + } + } + FAIL() << "no maps entry containing mapping at " << m.ptr(); +} + TEST(ProcSelfFd, OpenFd) { int pipe_fds[2]; ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); @@ -692,6 +713,30 @@ TEST(ProcSelfFd, OpenFd) { ASSERT_THAT(close(pipe_fds[1]), SyscallSucceeds()); } +static void CheckFdDirGetdentsDuplicates(const std::string& path) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_RDONLY | O_DIRECTORY)); + // Open a FD whose value is supposed to be much larger than + // the number of FDs opened by current process. + auto newfd = fcntl(fd.get(), F_DUPFD, 1024); + EXPECT_GE(newfd, 1024); + auto fd_closer = Cleanup([newfd]() { close(newfd); }); + auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir(path.c_str(), false)); + absl::node_hash_set<std::string> fd_files_dedup(fd_files.begin(), + fd_files.end()); + EXPECT_EQ(fd_files.size(), fd_files_dedup.size()); +} + +// This is a regression test for gvisor.dev/issues/3894 +TEST(ProcSelfFd, GetdentsDuplicates) { + CheckFdDirGetdentsDuplicates("/proc/self/fd"); +} + +// This is a regression test for gvisor.dev/issues/3894 +TEST(ProcSelfFdInfo, GetdentsDuplicates) { + CheckFdDirGetdentsDuplicates("/proc/self/fdinfo"); +} + TEST(ProcSelfFdInfo, CorrectFds) { // Make sure there is at least one open file. auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -735,8 +780,12 @@ TEST(ProcSelfFdInfo, Flags) { } TEST(ProcSelfExe, Absolute) { - auto exe = ASSERT_NO_ERRNO_AND_VALUE( - ReadLink(absl::StrCat("/proc/", getpid(), "/exe"))); + auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe")); + EXPECT_EQ(exe[0], '/'); +} + +TEST(ProcSelfCwd, Absolute) { + auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/cwd")); EXPECT_EQ(exe[0], '/'); } @@ -771,17 +820,12 @@ TEST(ProcCpuinfo, DeniesWriteNonRoot) { constexpr int kNobody = 65534; EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); - // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in - // kernfs. - if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { - EXPECT_THAT(truncate("/proc/cpuinfo", 123), - SyscallFailsWithErrno(EACCES)); - } + EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EACCES)); }); } // With root privileges, it is possible to open /proc/cpuinfo with write mode, -// but all write operations will return EIO. +// but all write operations should fail. TEST(ProcCpuinfo, DeniesWriteRoot) { // VFS1 does not behave differently for root/non-root. SKIP_IF(IsRunningWithVFS1()); @@ -790,16 +834,10 @@ TEST(ProcCpuinfo, DeniesWriteRoot) { int fd; EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds()); if (fd > 0) { - EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO)); - EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO)); - } - // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in - // kernfs. - if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { - if (fd > 0) { - EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO)); - } - EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO)); + // Truncate is not tested--it may succeed on some kernels without doing + // anything. + EXPECT_THAT(write(fd, "x", 1), SyscallFails()); + EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFails()); } } @@ -1439,6 +1477,16 @@ TEST(ProcPidExe, Subprocess) { EXPECT_EQ(actual, expected_absolute_path); } +// /proc/PID/cwd points to the correct directory. +TEST(ProcPidCwd, Subprocess) { + auto want = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); + + char got[PATH_MAX + 1] = {}; + ASSERT_THAT(ReadlinkWhileRunning("cwd", got, sizeof(got)), + SyscallSucceedsWithValue(Gt(0))); + EXPECT_EQ(got, want); +} + // Test whether /proc/PID/ files can be read for a running process. TEST(ProcPidFile, SubprocessRunning) { char buf[1]; @@ -2159,6 +2207,18 @@ TEST(Proc, PidTidIOAccounting) { noop.Join(); } +TEST(Proc, Statfs) { + struct statfs st; + EXPECT_THAT(statfs("/proc", &st), SyscallSucceeds()); + if (IsRunningWithVFS1()) { + EXPECT_EQ(st.f_type, ANON_INODE_FS_MAGIC); + } else { + EXPECT_EQ(st.f_type, PROC_SUPER_MAGIC); + } + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 3377b65cf..23677e296 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -39,6 +39,7 @@ namespace testing { namespace { constexpr const char kProcNet[] = "/proc/net"; +constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward"; TEST(ProcNetSymlinkTarget, FileMode) { struct stat s; @@ -477,6 +478,84 @@ TEST(ProcNetSnmp, CheckSnmp) { EXPECT_EQ(value_count, 1); } +TEST(ProcSysNetIpv4Recovery, Exists) { + EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_recovery", O_RDONLY), + SyscallSucceeds()); +} + +TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) { + // TODO(b/162988252): Enable save/restore for this test after the bug is + // fixed. + DisableSave ds; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/tcp_recovery", O_RDWR)); + + char buf[10] = {'\0'}; + char to_write = '2'; + + // Check initial value is set to 1. + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "1\n"), 0); + + // Set tcp_recovery to one of the allowed constants. + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "2\n"), 0); + + // Set tcp_recovery to any random value. + char kMessage[] = "100"; + EXPECT_THAT(PwriteFd(fd.get(), kMessage, strlen(kMessage), 0), + SyscallSucceedsWithValue(strlen(kMessage))); + EXPECT_THAT(PreadFd(fd.get(), buf, sizeof(kMessage), 0), + SyscallSucceedsWithValue(sizeof(kMessage))); + EXPECT_EQ(strcmp(buf, "100\n"), 0); +} + +TEST(ProcSysNetIpv4IpForward, Exists) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY)); +} + +TEST(ProcSysNetIpv4IpForward, DefaultValueEqZero) { + // Test is only valid in sandbox. Not hermetic in native tests + // running on a arbitrary machine. + SKIP_IF(!IsRunningOnGvisor()); + auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY)); + + char buf = 101; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_EQ(buf, '0') << "unexpected ip_forward: " << buf; +} + +TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDWR)); + + char buf; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_TRUE(buf == '0' || buf == '1') << "unexpected ip_forward: " << buf; + + // constexpr char to_write = '1'; + char to_write = (buf == '1') ? '0' : '1'; + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + + buf = 0; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_EQ(buf, to_write); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index f9392b9e0..0b174e2be 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -51,6 +51,7 @@ using ::testing::AnyOf; using ::testing::Contains; using ::testing::Eq; using ::testing::Not; +using SubprocessCallback = std::function<void()>; // Tests Unix98 pseudoterminals. // @@ -389,15 +390,15 @@ TEST(PtyTrunc, Truncate) { // (f)truncate should. FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); - int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + int n = ASSERT_NO_ERRNO_AND_VALUE(ReplicaID(master)); std::string spath = absl::StrCat("/dev/pts/", n); - FileDescriptor slave = + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(replica.get(), 0), SyscallFailsWithErrno(EINVAL)); } TEST(BasicPtyTest, StatUnopenedMaster) { @@ -453,16 +454,16 @@ void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) { EXPECT_EQ(expected, n); } -TEST(BasicPtyTest, OpenMasterSlave) { +TEST(BasicPtyTest, OpenMasterReplica) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); } -// The slave entry in /dev/pts/ disappears when the master is closed, even if -// the slave is still open. -TEST(BasicPtyTest, SlaveEntryGoneAfterMasterClose) { +// The replica entry in /dev/pts/ disappears when the master is closed, even if +// the replica is still open. +TEST(BasicPtyTest, ReplicaEntryGoneAfterMasterClose) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); // Get pty index. int index = -1; @@ -482,12 +483,12 @@ TEST(BasicPtyTest, Getdents) { FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index1 = -1; ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds()); - FileDescriptor slave1 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master1)); + FileDescriptor replica1 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master1)); FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index2 = -1; ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds()); - FileDescriptor slave2 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master2)); + FileDescriptor replica2 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master2)); // The directory contains ptmx, index1, and index2. (Plus any additional PTYs // unrelated to this test.) @@ -519,59 +520,60 @@ class PtyTest : public ::testing::Test { protected: void SetUp() override { master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_)); } void DisableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag &= ~ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } void EnableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag |= ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } - // Master and slave ends of the PTY. Non-blocking. + // Master and replica ends of the PTY. Non-blocking. FileDescriptor master_; - FileDescriptor slave_; + FileDescriptor replica_; }; -// Master to slave sanity test. -TEST_F(PtyTest, WriteMasterToSlave) { - // N.B. by default, the slave reads nothing until the master writes a newline. +// Master to replica sanity test. +TEST_F(PtyTest, WriteMasterToReplica) { + // N.B. by default, the replica reads nothing until the master writes a + // newline. constexpr char kBuf[] = "hello\n"; EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1), SyscallSucceedsWithValue(sizeof(kBuf) - 1)); - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be + // Linux moves data from the master to the replica via async work scheduled + // via tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". char buf[sizeof(kBuf)] = {}; - ExpectReadable(slave_, sizeof(buf) - 1, buf); + ExpectReadable(replica_, sizeof(buf) - 1, buf); EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0); } -// Slave to master sanity test. -TEST_F(PtyTest, WriteSlaveToMaster) { - // N.B. by default, the master reads nothing until the slave writes a newline, - // and the master gets a carriage return. +// Replica to master sanity test. +TEST_F(PtyTest, WriteReplicaToMaster) { + // N.B. by default, the master reads nothing until the replica writes a + // newline, and the master gets a carriage return. constexpr char kInput[] = "hello\n"; constexpr char kExpected[] = "hello\r\n"; - EXPECT_THAT(WriteFd(slave_.get(), kInput, sizeof(kInput) - 1), + EXPECT_THAT(WriteFd(replica_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be + // Linux moves data from the master to the replica via async work scheduled + // via tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". @@ -587,32 +589,33 @@ TEST_F(PtyTest, WriteInvalidUTF8) { SyscallSucceedsWithValue(sizeof(c))); } -// Both the master and slave report the standard default termios settings. +// Both the master and replica report the standard default termios settings. // -// Note that TCGETS on the master actually redirects to the slave (see comment +// Note that TCGETS on the master actually redirects to the replica (see comment // on MasterTermiosUnchangable). TEST_F(PtyTest, DefaultTermios) { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); } -// Changing termios from the master actually affects the slave. +// Changing termios from the master actually affects the replica. // -// TCSETS on the master actually redirects to the slave (see comment on +// TCSETS on the master actually redirects to the replica (see comment on // MasterTermiosUnchangable). -TEST_F(PtyTest, TermiosAffectsSlave) { +TEST_F(PtyTest, TermiosAffectsReplica) { struct kernel_termios master_termios = {}; EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); master_termios.c_lflag ^= ICANON; EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); - struct kernel_termios slave_termios = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &slave_termios), SyscallSucceeds()); - EXPECT_EQ(master_termios, slave_termios); + struct kernel_termios replica_termios = {}; + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &replica_termios), + SyscallSucceeds()); + EXPECT_EQ(master_termios, replica_termios); } // The master end of the pty has termios: @@ -627,7 +630,7 @@ TEST_F(PtyTest, TermiosAffectsSlave) { // // (From drivers/tty/pty.c:unix98_pty_init) // -// All termios control ioctls on the master actually redirect to the slave +// All termios control ioctls on the master actually redirect to the replica // (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the // master termios. // @@ -640,7 +643,7 @@ TEST_F(PtyTest, MasterTermiosUnchangable) { EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); ExpectReadable(master_, 1, &c); EXPECT_EQ(c, '\r'); // ICRNL had no effect! @@ -653,15 +656,15 @@ TEST_F(PtyTest, TermiosICRNL) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= ICRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\n'); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ONLCR rewrites output \n to \r\n. @@ -669,42 +672,42 @@ TEST_F(PtyTest, TermiosONLCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Extra byte for NUL for EXPECT_STREQ. char buf[3] = {}; ExpectReadable(master_, 2, buf); EXPECT_STREQ(buf, "\r\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosIGNCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. - ASSERT_THAT(PollAndReadFd(slave_.get(), &c, 1, kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), &c, 1, kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); } -// Test that we can successfully poll for readable data from the slave. -TEST_F(PtyTest, TermiosPollSlave) { +// Test that we can successfully poll for readable data from the replica. +TEST_F(PtyTest, TermiosPollReplica) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); absl::Notification notify; - int sfd = slave_.get(); + int sfd = replica_.get(); ScopedThread th([sfd, ¬ify]() { notify.Notify(); @@ -753,33 +756,33 @@ TEST_F(PtyTest, TermiosPollMaster) { absl::SleepFor(absl::Seconds(1)); char s[] = "foo\n"; - ASSERT_THAT(WriteFd(slave_.get(), s, strlen(s) + 1), SyscallSucceeds()); + ASSERT_THAT(WriteFd(replica_.get(), s, strlen(s) + 1), SyscallSucceeds()); } TEST_F(PtyTest, TermiosINLCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= INLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\r'); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosONOCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONOCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), @@ -789,7 +792,7 @@ TEST_F(PtyTest, TermiosONOCR) { // out of the other end. constexpr char kInput[] = "foo\r"; constexpr int kInputSize = sizeof(kInput) - 1; - ASSERT_THAT(WriteFd(slave_.get(), kInput, kInputSize), + ASSERT_THAT(WriteFd(replica_.get(), kInput, kInputSize), SyscallSucceedsWithValue(kInputSize)); char buf[kInputSize] = {}; @@ -800,7 +803,7 @@ TEST_F(PtyTest, TermiosONOCR) { ExpectFinished(master_); // Terminal should be at column 0 again, so no CR can be read. - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), @@ -811,11 +814,11 @@ TEST_F(PtyTest, TermiosOCRNL) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= OCRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); ExpectReadable(master_, 1, &c); EXPECT_EQ(c, '\n'); @@ -831,24 +834,24 @@ TEST_F(PtyTest, VEOLTermination) { ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); // Set the EOL character to '=' and write it. constexpr char delim = '='; struct kernel_termios t = DefaultTermios(); t.c_cc[VEOL] = delim; - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now we can read, as sending EOL caused the line to become available. - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_EQ(memcmp(buf, kInput, sizeof(kInput)), 0); - ExpectReadable(slave_, 1, buf); + ExpectReadable(replica_, 1, buf); EXPECT_EQ(buf[0], '='); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write more than the 4096 character limit, then a @@ -864,9 +867,9 @@ TEST_F(PtyTest, CanonBigWrite) { // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that data written in canonical mode can be read immediately once @@ -880,15 +883,15 @@ TEST_F(PtyTest, SwitchCanonToNoncanon) { // Nothing available yet. char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); DisableCanonical(); - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { @@ -901,10 +904,10 @@ TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { @@ -917,7 +920,7 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); // Wait for the input queue to fill. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); constexpr char delim = '\n'; ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); @@ -925,12 +928,12 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); // We can also read the remaining characters. - ExpectReadable(slave_, 6, buf); + ExpectReadable(replica_, 6, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { @@ -942,15 +945,15 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput) - 1)); EnableCanonical(); // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput) - 1, buf); + ExpectReadable(replica_, sizeof(kInput) - 1, buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { @@ -964,14 +967,14 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); EnableCanonical(); // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write over the 4095 noncanonical limit, then read out @@ -990,17 +993,17 @@ TEST_F(PtyTest, NoncanonBigWrite) { } // We should be able to read out everything. Sleep a bit so that Linux has a - // chance to move data from the master to the slave. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + // chance to move data from the master to the replica. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); for (int i = 0; i < kInputSize; i++) { // This makes too many syscalls for save/restore. const DisableSave ds; char c; - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); ASSERT_EQ(c, kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1015,18 +1018,18 @@ TEST_F(PtyTest, TermiosICANONNewline) { char buf[5] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = '\n'; ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(input) + 1)); - ExpectReadable(slave_, sizeof(input) + 1, buf); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(input) + 1)); + ExpectReadable(replica_, sizeof(input) + 1, buf); EXPECT_STREQ(buf, "abc\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1041,16 +1044,16 @@ TEST_F(PtyTest, TermiosICANONEOF) { char buf[4] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = ControlCharacter('D'); ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. Note that ^D is not included. - ExpectReadable(slave_, sizeof(input), buf); + ExpectReadable(replica_, sizeof(input), buf); EXPECT_STREQ(buf, "abc"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON limits us to 4096 bytes including a terminating character. Anything @@ -1076,12 +1079,12 @@ TEST_F(PtyTest, CanonDiscard) { // There should be multiple truncated lines available to read. for (int i = 0; i < kIter; i++) { char buf[kInputSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); EXPECT_EQ(buf[kMaxLineSize - 1], delim); EXPECT_EQ(buf[kMaxLineSize - 2], kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, CanonMultiline) { @@ -1096,15 +1099,15 @@ TEST_F(PtyTest, CanonMultiline) { // Get the first line. char line1[8] = {}; - ExpectReadable(slave_, sizeof(kInput1) - 1, line1); + ExpectReadable(replica_, sizeof(kInput1) - 1, line1); EXPECT_STREQ(line1, kInput1); // Get the second line. char line2[8] = {}; - ExpectReadable(slave_, sizeof(kInput2) - 1, line2); + ExpectReadable(replica_, sizeof(kInput2) - 1, line2); EXPECT_STREQ(line2, kInput2); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { @@ -1121,15 +1124,15 @@ TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { SyscallSucceedsWithValue(sizeof(kInput2) - 1)); ASSERT_NO_ERRNO( - WaitUntilReceived(slave_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); EnableCanonical(); // Get all together as one line. char line[9] = {}; - ExpectReadable(slave_, 8, line); + ExpectReadable(replica_, 8, line); EXPECT_STREQ(line, kExpected); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchTwiceMultiline) { @@ -1146,15 +1149,15 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { // All written characters have to make it into the input queue before // canonical mode is re-enabled. If the final '!' character hasn't been // enqueued before canonical mode is re-enabled, it won't be readable. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kExpected.size())); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kExpected.size())); EnableCanonical(); // Get all together as one line. char line[10] = {}; - ExpectReadable(slave_, 9, line); + ExpectReadable(replica_, 9, line); EXPECT_STREQ(line, kExpected.c_str()); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, QueueSize) { @@ -1162,7 +1165,7 @@ TEST_F(PtyTest, QueueSize) { constexpr char kInput1[] = "GO\n"; ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); // Ensure that writing more (beyond what is readable) does not impact the // readable size. @@ -1171,7 +1174,7 @@ TEST_F(PtyTest, QueueSize) { ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize), SyscallSucceedsWithValue(kMaxLineSize)); int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE( - WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); EXPECT_EQ(inputBufSize, sizeof(kInput1) - 1); } @@ -1196,9 +1199,9 @@ TEST_F(PtyTest, PartialBadBuffer) { EXPECT_THAT(WriteFd(master_.get(), kBuf, size), SyscallSucceedsWithValue(size)); - // Read from the slave into bad_buffer. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), size)); - EXPECT_THAT(ReadFd(slave_.get(), bad_buffer, size), + // Read from the replica into bad_buffer. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size)); + EXPECT_THAT(ReadFd(replica_.get(), bad_buffer, size), SyscallFailsWithErrno(EFAULT)); EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr; @@ -1218,16 +1221,16 @@ TEST_F(PtyTest, SimpleEcho) { TEST_F(PtyTest, GetWindowSize) { struct winsize ws; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); EXPECT_EQ(ws.ws_row, 0); EXPECT_EQ(ws.ws_col, 0); } -TEST_F(PtyTest, SetSlaveWindowSize) { +TEST_F(PtyTest, SetReplicaWindowSize) { constexpr uint16_t kRows = 343; constexpr uint16_t kCols = 2401; struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(slave_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws), @@ -1243,7 +1246,7 @@ TEST_F(PtyTest, SetMasterWindowSize) { ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &retrieved_ws), + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &retrieved_ws), SyscallSucceeds()); EXPECT_EQ(retrieved_ws.ws_row, kRows); EXPECT_EQ(retrieved_ws.ws_col, kCols); @@ -1253,7 +1256,7 @@ class JobControlTest : public ::testing::Test { protected: void SetUp() override { master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_)); // Make this a session leader, which also drops the controlling terminal. // In the gVisor test environment, this test will be run as the session @@ -1263,61 +1266,82 @@ class JobControlTest : public ::testing::Test { } } - // Master and slave ends of the PTY. Non-blocking. + PosixError RunInChild(SubprocessCallback childFunc) { + pid_t child = fork(); + if (!child) { + childFunc(); + _exit(0); + } + int wstatus; + if (waitpid(child, &wstatus, 0) != child) { + return PosixError( + errno, absl::StrCat("child failed with wait status: ", wstatus)); + } + return PosixError(wstatus, "process returned"); + } + + // Master and replica ends of the PTY. Non-blocking. FileDescriptor master_; - FileDescriptor slave_; + FileDescriptor replica_; }; TEST_F(JobControlTest, SetTTYMaster) { - ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(master_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(!replica_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYNonLeader) { // Fork a process that won't be the session leader. - pid_t child = fork(); - if (!child) { - // We shouldn't be able to set the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + auto res = + RunInChild([=]() { TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 0)); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYBadArg) { - // Despite the man page saying arg should be 0 here, Linux doesn't actually - // check. - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYDifferentSession) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Fork, join a new session, and try to steal the parent's controlling - // terminal, which should fail. - pid_t child = fork(); - if (!child) { + auto res = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - // We shouldn't be able to steal the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); - _exit(0); - } + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int gcwstatus; + TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild); + TEST_PCHECK(gcwstatus == 0); + }); } TEST_F(JobControlTest, ReleaseTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we // disconnect they TTY. @@ -1327,48 +1351,60 @@ TEST_F(JobControlTest, ReleaseTTY) { sigemptyset(&sa.sa_mask); struct sigaction old_sa; EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); } TEST_F(JobControlTest, ReleaseUnsetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + ASSERT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, ReleaseWrongTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(master_.get(), TIOCNOTTY) < 0 && errno == ENOTTY); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, ReleaseTTYNonLeader) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, ReleaseTTYDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t child = fork(); - if (!child) { - // Join a new session, then try to disconnect. + auto ret = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + pid_t grandchild = fork(); + if (!grandchild) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } // Used by the child process spawned in ReleaseTTYSignals to track received @@ -1387,7 +1423,7 @@ void sig_handler(int signum) { received |= signum; } // - Checks that thread 1 got both signals // - Checks that thread 2 didn't get any signals. TEST_F(JobControlTest, ReleaseTTYSignals) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); received = 0; struct sigaction sa = {}; @@ -1439,7 +1475,7 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { // Release the controlling terminal, sending SIGHUP and SIGCONT to all other // processes in this process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); @@ -1456,20 +1492,21 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { } TEST_F(JobControlTest, GetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - pid_t foreground_pgid; - pid_t pid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), - SyscallSucceeds()); - ASSERT_THAT(pid = getpid(), SyscallSucceeds()); - - ASSERT_EQ(foreground_pgid, pid); + auto res = RunInChild([=]() { + pid_t pid, foreground_pgid; + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + TEST_PCHECK(!ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid)); + TEST_PCHECK((pid = getpid()) >= 0); + TEST_PCHECK(pid == foreground_pgid); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // At this point there's no controlling terminal, so TIOCGPGRP should fail. pid_t foreground_pgid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + ASSERT_THAT(ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid), SyscallFailsWithErrno(ENOTTY)); } @@ -1479,113 +1516,125 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // - sets that child as the foreground process group // - kills its child and sets itself as the foreground process group. TEST_F(JobControlTest, SetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - sigaction(SIGTTOU, &sa, NULL); - - // Set ourself as the foreground process group. - ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); - - // Create a new process that just waits to be signaled. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!pause()); - // We should never reach this. - _exit(1); - } - - // Make the child its own process group, then make it the controlling process - // group of the terminal. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + TEST_PCHECK(!tcsetpgrp(replica_.get(), getpgid(0))); + + // Create a new process that just waits to be signaled. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } - // Sanity check - we're still the controlling session. - ASSERT_EQ(getsid(0), getsid(child)); + // Make the child its own process group, then make it the controlling + // process group of the terminal. + TEST_PCHECK(!setpgid(grandchild, grandchild)); + TEST_PCHECK(!tcsetpgrp(replica_.get(), grandchild)); - // Signal the child, wait for it to exit, then retake the terminal. - ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSIGNALED(wstatus)); - ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + // Sanity check - we're still the controlling session. + TEST_PCHECK(getsid(0) == getsid(grandchild)); - // Set ourself as the foreground process. - pid_t pgid; - ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); + // Signal the child, wait for it to exit, then retake the terminal. + TEST_PCHECK(!kill(grandchild, SIGTERM)); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFSIGNALED(wstatus)); + TEST_PCHECK(WTERMSIG(wstatus) == SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + TEST_PCHECK(pgid = getpgid(0) == 0); + TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid)); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { pid_t pid = getpid(); - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + ASSERT_THAT(ioctl(replica_.get(), TIOCSPGRP, &pid), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t pid = -1; - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), - SyscallFailsWithErrno(EINVAL)); + pid_t pid = -1; + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &pid) && errno == EINVAL); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Create a new process, put it in a new process group, make that group the - // foreground process group, then have the process wait. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!setpgid(0, 0)); - _exit(0); - } + auto ret = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } - // Wait for the child to exit. - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - // The child's process group doesn't exist anymore - this should fail. - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(ESRCH)); + // Wait for the child to exit. + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + // The child's process group doesn't exist anymore - this should fail. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 && + errno == ESRCH); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - int sync_setsid[2]; - int sync_exit[2]; - ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds()); - ASSERT_THAT(pipe(sync_exit), SyscallSucceeds()); + int sync_setsid[2]; + int sync_exit[2]; + TEST_PCHECK(pipe(sync_setsid) >= 0); + TEST_PCHECK(pipe(sync_exit) >= 0); - // Create a new process and put it in a new session. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(setsid() >= 0); - // Tell the parent we're in a new session. - char c = 'c'; - TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); - TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); - _exit(0); - } + // Create a new process and put it in a new session. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + char c = 'c'; + TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); + TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); + _exit(0); + } - // Wait for the child to tell us it's in a new session. - char c = 'c'; - ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1)); + // Wait for the child to tell us it's in a new session. + char c = 'c'; + TEST_PCHECK(ReadFd(sync_setsid[0], &c, 1) == 1); - // Child is in a new session, so we can't make it the foregroup process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(EPERM)); + // Child is in a new session, so we can't make it the foregroup process + // group. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) && + errno == EPERM); - EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1)); + TEST_PCHECK(WriteFd(sync_exit[1], &c, 1) == 1); - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(wstatus)); - EXPECT_EQ(WEXITSTATUS(wstatus), 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFEXITED(wstatus)); + TEST_PCHECK(!WEXITSTATUS(wstatus)); + }); + ASSERT_NO_ERRNO(ret); } // Verify that we don't hang when creating a new session from an orphaned diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 1d7dbefdb..4ac648729 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -50,10 +50,10 @@ TEST(JobControlRootTest, StealTTY) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); - // Make slave the controlling terminal. - ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + // Make replica the controlling terminal. + ASSERT_THAT(ioctl(replica.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Fork, join a new session, and try to steal the parent's controlling // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg @@ -62,9 +62,9 @@ TEST(JobControlRootTest, StealTTY) { if (!child) { ASSERT_THAT(setsid(), SyscallSucceeds()); // We shouldn't be able to steal the terminal with the wrong arg value. - TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(replica.get(), TIOCSCTTY, 0)); // We should be able to steal it if we are true root. - TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1)); + TEST_PCHECK(true_root == !ioctl(replica.get(), TIOCSCTTY, 1)); _exit(0); } diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 05c4ed03f..54709371c 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <linux/capability.h> +#include <linux/filter.h> #include <netinet/in.h> #include <netinet/ip.h> #include <netinet/ip6.h> @@ -21,6 +22,7 @@ #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> + #include <algorithm> #include "gtest/gtest.h" @@ -258,6 +260,27 @@ TEST_P(RawSocketTest, SendWithoutConnectFails) { SyscallFailsWithErrno(EDESTADDRREQ)); } +// Wildcard Bind. +TEST_P(RawSocketTest, BindToWildcard) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + struct sockaddr_storage addr; + addr = {}; + + // We don't set ports because raw sockets don't have a notion of ports. + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_any; + } + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); +} + // Bind to localhost. TEST_P(RawSocketTest, BindToLocalhost) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); @@ -790,10 +813,26 @@ void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); } -INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest, - ::testing::Combine( - ::testing::Values(IPPROTO_TCP, IPPROTO_UDP), - ::testing::Values(AF_INET, AF_INET6))); +TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} // AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. TEST(RawSocketTest, IPv6ProtoRaw) { @@ -813,6 +852,11 @@ TEST(RawSocketTest, IPv6ProtoRaw) { SyscallFailsWithErrno(EINVAL)); } +INSTANTIATE_TEST_SUITE_P( + AllInetTests, RawSocketTest, + ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP), + ::testing::Values(AF_INET, AF_INET6))); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index 0a27506aa..2f25aceb2 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -167,7 +167,7 @@ TEST_F(RawHDRINCL, NotReadable) { // nothing to be read. char buf[117]; ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EINVAL)); + SyscallFailsWithErrno(EAGAIN)); } // Test that we can connect() to a valid IP (loopback). @@ -178,6 +178,9 @@ TEST_F(RawHDRINCL, ConnectToLoopback) { } TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { + // FIXME(gvisor.dev/issue/3159): Test currently flaky. + SKIP_IF(true); + struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); @@ -273,14 +276,17 @@ TEST_F(RawHDRINCL, SendAndReceive) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr iphdr = {}; - memcpy(&iphdr, recv_buf, sizeof(iphdr)); - EXPECT_EQ(iphdr.id, 0); + // The packet ID should not be 0, as the packet has DF=0. + struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); + EXPECT_NE(iphdr->id, 0); } -// Send and receive a packet with nonzero IP ID. -TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { +// Send and receive a packet where the sendto address is not the same as the +// provided destination. +TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { + // FIXME(gvisor.dev/issue/3160): Test currently flaky. + SKIP_IF(true); + int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( @@ -292,19 +298,24 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); - // Construct a packet with an IP header, UDP header, and payload. Make the - // payload large enough to force an IP ID to be assigned. - constexpr char kPayload[128] = {}; + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "toto"; char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; ASSERT_TRUE( FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + // Overwrite the IP destination address with an IP we can't get to. + struct iphdr iphdr = {}; + memcpy(&iphdr, packet, sizeof(iphdr)); + iphdr.daddr = 42; + memcpy(packet, &iphdr, sizeof(iphdr)); socklen_t addrlen = sizeof(addr_); ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - // Receive the payload. + // Receive the payload, since sendto should replace the bad destination with + // localhost. char recv_buf[sizeof(packet)]; struct sockaddr_in src; socklen_t src_size = sizeof(src); @@ -318,47 +329,58 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should not be 0, as the packet was more than 68 bytes. - struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); - EXPECT_NE(iphdr->id, 0); + // The packet ID should not be 0, as the packet has DF=0. + struct iphdr recv_iphdr = {}; + memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); + EXPECT_NE(recv_iphdr.id, 0); + // The destination address should be localhost, not the bad IP we set + // initially. + EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); } -// Send and receive a packet where the sendto address is not the same as the -// provided destination. -TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { +// Send and receive a packet w/ the IP_HDRINCL option set. +TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) { int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); } - // IPPROTO_RAW sockets are write-only. We'll have to open another socket to - // read what we write. - FileDescriptor udp_sock = + FileDescriptor recv_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + FileDescriptor send_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + // Enable IP_HDRINCL option so that we can build and send w/ an IP + // header. + constexpr int kSockOptOn = 1; + ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // This is not strictly required but we do it to make sure that setting + // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving + // packets. + ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // Construct a packet with an IP header, UDP header, and payload. constexpr char kPayload[] = "toto"; char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; ASSERT_TRUE( FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); - // Overwrite the IP destination address with an IP we can't get to. - struct iphdr iphdr = {}; - memcpy(&iphdr, packet, sizeof(iphdr)); - iphdr.daddr = 42; - memcpy(packet, &iphdr, sizeof(iphdr)); socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, + ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0, reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - // Receive the payload, since sendto should replace the bad destination with - // localhost. + // Receive the payload. char recv_buf[sizeof(packet)]; struct sockaddr_in src; socklen_t src_size = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, + ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_size), SyscallSucceedsWithValue(sizeof(packet))); EXPECT_EQ( @@ -368,13 +390,20 @@ TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr recv_iphdr = {}; - memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); - EXPECT_EQ(recv_iphdr.id, 0); - // The destination address should be localhost, not the bad IP we set - // initially. - EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); + struct iphdr iphdr = {}; + memcpy(&iphdr, recv_buf, sizeof(iphdr)); + EXPECT_NE(iphdr.id, 0); + + // Also verify that the packet we just sent was not delivered to the + // IPPROTO_RAW socket. + { + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_size), + SyscallFailsWithErrno(EAGAIN)); + } } } // namespace diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 3de898df7..1b9dbc584 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -416,6 +416,28 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); } +// Set and get SO_LINGER. +TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { // We're going to receive both the echo request and reply, but the order is // indeterminate. diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc index 09703b5c1..71073bb3c 100644 --- a/test/syscalls/linux/readahead.cc +++ b/test/syscalls/linux/readahead.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -29,7 +30,15 @@ TEST(ReadaheadTest, InvalidFD) { EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF)); } +TEST(ReadaheadTest, UnsupportedFile) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + ASSERT_THAT(readahead(sock.get(), 1, 1), SyscallFailsWithErrno(EINVAL)); +} + TEST(ReadaheadTest, InvalidOffset) { + // This test is not valid for some Linux Kernels. + SKIP_IF(!IsRunningOnGvisor()); const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); @@ -79,6 +88,8 @@ TEST(ReadaheadTest, WriteOnly) { } TEST(ReadaheadTest, InvalidSize) { + // This test is not valid on some Linux kernels. + SKIP_IF(!IsRunningOnGvisor()); const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index 833c0dc4f..5458f54ad 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -170,6 +170,9 @@ TEST(RenameTest, FileOverwritesFile) { } TEST(RenameTest, DirectoryOverwritesDirectoryLinkCount) { + // Directory link counts are synthetic on overlay filesystems. + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + auto parent1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2)); diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc index 4bfb1ff56..94f9154a0 100644 --- a/test/syscalls/linux/rseq.cc +++ b/test/syscalls/linux/rseq.cc @@ -24,6 +24,7 @@ #include "test/syscalls/linux/rseq/uapi.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" +#include "test/util/posix_error.h" #include "test/util/test_util.h" namespace gvisor { @@ -31,6 +32,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + // Syscall test for rseq (restartable sequences). // // We must be very careful about how these tests are written. Each thread may @@ -98,7 +102,7 @@ void RunChildTest(std::string test_case, int want_status) { int status = 0; ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_EQ(status, want_status); + ASSERT_THAT(status, AnyOf(Eq(want_status), Eq(128 + want_status))); } // Test that rseq must be aligned. diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc index f036db26d..6f5d38bba 100644 --- a/test/syscalls/linux/rseq/rseq.cc +++ b/test/syscalls/linux/rseq/rseq.cc @@ -74,84 +74,95 @@ int TestUnaligned() { // Sanity test that registration works. int TestRegister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } return 0; -}; +} // Registration can't be done twice. int TestDoubleRegister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { + ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != EBUSY) { return 1; } return 0; -}; +} // Registration can be done again after unregister. int TestRegisterUnregister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); - sys_errno(ret) != 0) { + ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } return 0; -}; +} // The pointer to rseq must match on register/unregister. int TestUnregisterDifferentPtr() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } struct rseq r2 = {}; - if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); - sys_errno(ret) != EINVAL) { + + ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); + if (sys_errno(ret) != EINVAL) { return 1; } return 0; -}; +} // The signature must match on register/unregister. int TestUnregisterDifferentSignature() { constexpr int kSignature = 0; struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kSignature); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); - sys_errno(ret) != EPERM) { + ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); + if (sys_errno(ret) != EPERM) { return 1; } return 0; -}; +} // The CPU ID is initialized. int TestCPU() { struct rseq r = {}; r.cpu_id = kRseqCPUIDUninitialized; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } @@ -163,13 +174,13 @@ int TestCPU() { } return 0; -}; +} // Critical section is eventually aborted. int TestAbort() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -185,13 +196,13 @@ int TestAbort() { rseq_loop(&r, &cs); return 0; -}; +} // Abort may be before the critical section. int TestAbortBefore() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -207,13 +218,13 @@ int TestAbortBefore() { rseq_loop(&r, &cs); return 0; -}; +} // Signature must match. int TestAbortSignature() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + if (sys_errno(ret) != 0) { return 1; } @@ -229,13 +240,13 @@ int TestAbortSignature() { rseq_loop(&r, &cs); return 1; -}; +} // Abort must not be in the critical section. int TestAbortPreCommit() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + if (sys_errno(ret) != 0) { return 1; } @@ -251,13 +262,13 @@ int TestAbortPreCommit() { rseq_loop(&r, &cs); return 1; -}; +} // rseq.rseq_cs is cleared on abort. int TestAbortClearsCS() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -277,13 +288,13 @@ int TestAbortClearsCS() { } return 0; -}; +} // rseq.rseq_cs is cleared on abort outside of critical section. int TestInvalidAbortClearsCS() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -306,7 +317,7 @@ int TestInvalidAbortClearsCS() { } return 0; -}; +} // Exit codes: // 0 - Pass diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h index 3b7bb74b1..ff0dd6e48 100644 --- a/test/syscalls/linux/rseq/test.h +++ b/test/syscalls/linux/rseq/test.h @@ -20,22 +20,20 @@ namespace testing { // Test cases supported by rseq binary. -inline constexpr char kRseqTestUnaligned[] = "unaligned"; -inline constexpr char kRseqTestRegister[] = "register"; -inline constexpr char kRseqTestDoubleRegister[] = "double-register"; -inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; -inline constexpr char kRseqTestUnregisterDifferentPtr[] = - "unregister-different-ptr"; -inline constexpr char kRseqTestUnregisterDifferentSignature[] = +constexpr char kRseqTestUnaligned[] = "unaligned"; +constexpr char kRseqTestRegister[] = "register"; +constexpr char kRseqTestDoubleRegister[] = "double-register"; +constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; +constexpr char kRseqTestUnregisterDifferentPtr[] = "unregister-different-ptr"; +constexpr char kRseqTestUnregisterDifferentSignature[] = "unregister-different-signature"; -inline constexpr char kRseqTestCPU[] = "cpu"; -inline constexpr char kRseqTestAbort[] = "abort"; -inline constexpr char kRseqTestAbortBefore[] = "abort-before"; -inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; -inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; -inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; -inline constexpr char kRseqTestInvalidAbortClearsCS[] = - "invalid-abort-clears-cs"; +constexpr char kRseqTestCPU[] = "cpu"; +constexpr char kRseqTestAbort[] = "abort"; +constexpr char kRseqTestAbortBefore[] = "abort-before"; +constexpr char kRseqTestAbortSignature[] = "abort-signature"; +constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; +constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; +constexpr char kRseqTestInvalidAbortClearsCS[] = "invalid-abort-clears-cs"; } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 64123e904..a8bfb01f1 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -198,7 +198,39 @@ TEST(SendFileTest, SendAndUpdateFileOffset) { EXPECT_EQ(absl::string_view(kData, kHalfDataSize), absl::string_view(actual, bytes_sent)); - // Verify that the input file offset has been updated + // Verify that the input file offset has been updated. + ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), + SyscallSucceedsWithValue(kHalfDataSize)); + EXPECT_EQ( + absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent), + absl::string_view(actual, kHalfDataSize)); +} + +TEST(SendFileTest, SendToDevZeroAndUpdateFileOffset) { + // Create temp files. + // Test input string length must be > 2 AND even. + constexpr char kData[] = "The slings and arrows of outrageous fortune,"; + constexpr int kDataSize = sizeof(kData) - 1; + constexpr int kHalfDataSize = kDataSize / 2; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open /dev/zero as write only. + const FileDescriptor outf = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); + + // Send data and verify that sendfile returns the correct value. + int bytes_sent; + EXPECT_THAT( + bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize), + SyscallSucceedsWithValue(kHalfDataSize)); + + char actual[kHalfDataSize]; + // Verify that the input file offset has been updated. ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), SyscallSucceedsWithValue(kHalfDataSize)); EXPECT_EQ( @@ -250,7 +282,7 @@ TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) { EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize), absl::string_view(actual, bytes_sent)); - // Verify that the input file offset has been updated + // Verify that the input file offset has been updated. ASSERT_THAT(read(inf.get(), &actual, kQuarterDataSize), SyscallSucceedsWithValue(kQuarterDataSize)); @@ -501,6 +533,22 @@ TEST(SendFileTest, SendPipeWouldBlock) { SyscallFailsWithErrno(EWOULDBLOCK)); } +TEST(SendFileTest, SendPipeEOF) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, 123), + SyscallSucceedsWithValue(0)); +} + TEST(SendFileTest, SendPipeBlocks) { // Create temp file. constexpr char kData[] = diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index c7fdbb924..d6e8b3e59 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -29,6 +29,8 @@ namespace testing { namespace { using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; const uint64_t kAllocSize = kPageSize * 128ULL; @@ -394,7 +396,8 @@ TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) { }; EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); + IsPosixErrorOkAndHolds(AnyOf(Eq(W_EXITCODE(0, SIGSEGV)), + Eq(W_EXITCODE(0, 128 + SIGSEGV))))); } TEST(ShmTest, RequestingSegmentSmallerThanSHMMINFails) { diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index c20cd3fcc..e680d3dd7 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -14,6 +14,7 @@ #include <sys/socket.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -26,6 +27,9 @@ namespace gvisor { namespace testing { +// From linux/magic.h, but we can't depend on linux headers here. +#define SOCKFS_MAGIC 0x534F434B + TEST(SocketTest, UnixSocketPairProtocol) { int socks[2]; ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks), @@ -94,6 +98,19 @@ TEST(SocketTest, UnixSocketStat) { } } +TEST(SocketTest, UnixSocketStatFS) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + struct statfs st; + EXPECT_THAT(fstatfs(bound.get(), &st), SyscallSucceeds()); + EXPECT_EQ(st.f_type, SOCKFS_MAGIC); + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + using SocketOpenTest = ::testing::TestWithParam<int>; // UDS cannot be opened. diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index f7d6139f1..5d39e6fbd 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -462,7 +462,8 @@ TEST_P(AllSocketPairTest, SendTimeoutDefault) { TEST_P(AllSocketPairTest, SetGetSendTimeout) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - timeval tv = {.tv_sec = 89, .tv_usec = 42000}; + // tv_usec should be a multiple of 4000 to work on most systems. + timeval tv = {.tv_sec = 89, .tv_usec = 44000}; EXPECT_THAT( setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); @@ -472,8 +473,8 @@ TEST_P(AllSocketPairTest, SetGetSendTimeout) { EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &actual_tv, &len), SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv_sec, 89); - EXPECT_EQ(actual_tv.tv_usec, 42000); + EXPECT_EQ(actual_tv.tv_sec, tv.tv_sec); + EXPECT_EQ(actual_tv.tv_usec, tv.tv_usec); } TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { @@ -484,8 +485,9 @@ TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { int64_t extra_data; } ABSL_ATTRIBUTE_PACKED; + // tv_usec should be a multiple of 4000 to work on most systems. timeval_with_extra tv_extra = { - .tv = {.tv_sec = 0, .tv_usec = 123000}, + .tv = {.tv_sec = 0, .tv_usec = 124000}, }; EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, @@ -497,8 +499,8 @@ TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &actual_tv, &len), SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv.tv_sec, 0); - EXPECT_EQ(actual_tv.tv.tv_usec, 123000); + EXPECT_EQ(actual_tv.tv.tv_sec, tv_extra.tv.tv_sec); + EXPECT_EQ(actual_tv.tv.tv_usec, tv_extra.tv.tv_usec); } TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) { diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc index 6a232238d..6cd67123d 100644 --- a/test/syscalls/linux/socket_generic_stress.cc +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <poll.h> #include <stdio.h> #include <sys/ioctl.h> #include <sys/socket.h> @@ -29,6 +30,9 @@ namespace testing { using ConnectStressTest = SocketPairTest; TEST_P(ConnectStressTest, Reset65kTimes) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -37,6 +41,14 @@ TEST_P(ConnectStressTest, Reset65kTimes) { char sent_data[100] = {}; ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), SyscallSucceedsWithValue(sizeof(sent_data))); + // Poll the other FD to make sure that the data is in the receive buffer + // before closing it to ensure a RST is triggered. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->second_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); } } @@ -58,7 +70,54 @@ INSTANTIATE_TEST_SUITE_P( // a persistent listener (if applicable). using PersistentListenerConnectStressTest = SocketPairTest; -TEST_P(PersistentListenerConnectStressTest, 65kTimes) { +TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + if (GetParam().type == SOCK_STREAM) { + // Poll the other FD to make sure that we see the FIN from the other + // side before closing the second_fd. This ensures that the first_fd + // enters TIME-WAIT and not second_fd. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->second_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); + } +} + +TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); + if (GetParam().type == SOCK_STREAM) { + // Poll the other FD to make sure that we see the FIN from the other + // side before closing the first_fd. This ensures that the second_fd + // enters TIME-WAIT and not first_fd. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->first_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + } +} + +TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 18b9e4b70..11fcec443 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -97,11 +97,9 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { ASSERT_THAT(socketpair(AF_INET6, 0, 0, fd), SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - // Invalid AF will return ENOAFSUPPORT. - ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); - ASSERT_THAT(socketpair(8675309, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); + // Invalid AF will fail. + ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), SyscallFails()); + ASSERT_THAT(socketpair(8675309, 0, 0, fd), SyscallFails()); } enum class Operation { @@ -116,7 +114,8 @@ std::string OperationToString(Operation operation) { return "Bind"; case Operation::Connect: return "Connect"; - case Operation::SendTo: + // Operation::SendTo is the default. + default: return "SendTo"; } } @@ -861,36 +860,38 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { SyscallSucceedsWithValue(0)); } -// This test is disabled under random save as the the restore run -// results in the stack.Seed() being different which can cause -// sequence number of final connect to be one that is considered -// old and can cause the test to be flaky. -TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - +// setupTimeWaitClose sets up a socket endpoint in TIME_WAIT state. +// Callers can choose to perform active close on either ends of the connection +// and also specify if they want to enabled SO_REUSEADDR. +void setupTimeWaitClose(const TestAddress* listener, + const TestAddress* connector, bool reuse, + bool accept_close, sockaddr_storage* listen_addr, + sockaddr_storage* conn_bound_addr) { // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener->family(), SOCK_STREAM, IPPROTO_TCP)); + if (reuse) { + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + } + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(listen_addr), + listener->addr_len), SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; + socklen_t addrlen = listener->addr_len; ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + reinterpret_cast<sockaddr*>(listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener->family(), *listen_addr)); // Connect to the listening socket. FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + Socket(connector->family(), SOCK_STREAM, IPPROTO_TCP)); // We disable saves after this point as a S/R causes the netstack seed // to be regenerated which changes what ports/ISN is picked for a given @@ -901,11 +902,12 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { // // TODO(gvisor.dev/issue/940): S/R portSeed/portHint DisableSave ds; - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + sockaddr_storage conn_addr = connector->addr; + ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port)); ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), + connector->addr_len), SyscallSucceeds()); // Accept the connection. @@ -913,33 +915,146 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; + socklen_t conn_addrlen = connector->addr_len; ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(conn_bound_addr), &conn_addrlen), SyscallSucceeds()); - // close the accept FD to trigger TIME_WAIT on the accepted socket which - // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of - // TIME_WAIT. - accepted.reset(); - absl::SleepFor(absl::Seconds(1)); - conn_fd.reset(); + FileDescriptor active_closefd, passive_closefd; + if (accept_close) { + active_closefd = std::move(accepted); + passive_closefd = std::move(conn_fd); + } else { + active_closefd = std::move(conn_fd); + passive_closefd = std::move(accepted); + } + + // shutdown to trigger TIME_WAIT. + ASSERT_THAT(shutdown(active_closefd.get(), SHUT_RDWR), SyscallSucceeds()); + { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = passive_closefd.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + } + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = active_closefd.get(), + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); + + passive_closefd.reset(); + t.Join(); + active_closefd.reset(); + // This sleep is needed to reduce flake to ensure that the passive-close + // ensures the state transitions to CLOSE from LAST_ACK. absl::SleepFor(absl::Seconds(1)); +} - // Now bind and connect a new socket and verify that we can immediately - // rebind the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); +// These tests are disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +// +// Test re-binding of client and server bound addresses when the older +// connection is in TIME_WAIT. +TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + // Now bind a new socket and verify that we can immediately rebind the address + // bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + param.listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, + TCPPassiveCloseNoTimeWaitReuseTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + param.listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Now bind and connect new socket and verify that we can immediately rebind + // the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr)); + sockaddr_storage conn_addr = param.connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), + param.connector.addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); } TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { @@ -996,6 +1111,86 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { EXPECT_EQ(get, kUserTimeout); } +TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + { + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + } + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Trigger a RST by turning linger off and closing the socket. + struct linger opt = { + .l_onoff = 1, + .l_linger = 0, + }; + ASSERT_THAT( + setsockopt(conn_fd.get(), SOL_SOCKET, SO_LINGER, &opt, sizeof(opt)), + SyscallSucceeds()); + ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); + + if (IsRunningOnGvisor()) { + // Gvisor packet procssing is asynchronous and can take a bit of time in + // some cases so we give it a bit of time to process the RST packet before + // calling accept. + // + // There is nothing to poll() on so we have no choice but to use a sleep + // here. + absl::SleepFor(absl::Milliseconds(100)); + } + + sockaddr_storage accept_addr; + socklen_t addrlen = sizeof(accept_addr); + + auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( + listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); + ASSERT_EQ(addrlen, listener.addr_len); + + // TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed. + if (IsRunningOnGvisor()) { + char buf[10]; + ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(ECONNRESET)); + } else { + int err; + socklen_t optlen = sizeof(err); + ASSERT_THAT( + getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + SyscallSucceeds()); + ASSERT_EQ(err, ECONNRESET); + ASSERT_EQ(optlen, sizeof(err)); + } +} + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not // saved. Enable S/R once issue is fixed. TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { @@ -2469,6 +2664,44 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { SyscallSucceeds()); } +TEST_P(SocketMultiProtocolInetLoopbackTest, + MultipleBindsAllowedNoListeningReuseAddr) { + const auto& param = GetParam(); + // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP + // this is only permitted if there is no other listening socket. + SKIP_IF(param.type != SOCK_STREAM); + // Bind the v4 loopback on a v4 socket. + const TestAddress& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Now create a socket and bind it to the same port, this should + // succeed since there is no listening socket for the same port. + FileDescriptor second_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(second_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(second_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { auto const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index 2324c7f6a..1a0b53394 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -82,8 +82,11 @@ using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; // This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral // ports are already in use for a given destination ip/port. +// // We disable S/R because this test creates a large number of sockets. -TEST_P(SocketInetLoopbackTest, TestTCPPortExhaustion_NoRandomSave) { +// +// FIXME(b/162475855): This test is failing reliably. +TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -165,6 +168,71 @@ INSTANTIATE_TEST_SUITE_P( TestParam{V6Loopback(), V6Loopback()}), DescribeTestParam); +struct ProtocolTestParam { + std::string description; + int type; +}; + +std::string DescribeProtocolTestParam( + ::testing::TestParamInfo<ProtocolTestParam> const& info) { + return info.param.description; +} + +using SocketMultiProtocolInetLoopbackTest = + ::testing::TestWithParam<ProtocolTestParam>; + +TEST_P(SocketMultiProtocolInetLoopbackTest, + BindAvoidsListeningPortsReuseAddr_NoRandomSave) { + const auto& param = GetParam(); + // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP + // this is only permitted if there is no other listening socket. + SKIP_IF(param.type != SOCK_STREAM); + + DisableSave ds; // Too many syscalls. + + // A map of port to file descriptor binding the port. + std::map<uint16_t, FileDescriptor> listen_sockets; + + // Exhaust all ephemeral ports. + while (true) { + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int ret = bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len); + if (ret != 0) { + ASSERT_EQ(errno, EADDRINUSE); + break; + } + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + uint16_t port = reinterpret_cast<sockaddr_in*>(&bound_addr)->sin_port; + + // Newly bound port should not already be in use by a listening socket. + ASSERT_EQ(listen_sockets.find(port), listen_sockets.end()); + auto fd = bound_fd.get(); + listen_sockets.insert(std::make_pair(port, std::move(bound_fd))); + ASSERT_THAT(listen(fd, SOMAXCONN), SyscallSucceeds()); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllFamilies, SocketMultiProtocolInetLoopbackTest, + ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, + ProtocolTestParam{"UDP", SOCK_DGRAM}), + DescribeProtocolTestParam); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index c2ecb639f..f4b69c46c 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -34,6 +34,9 @@ namespace gvisor { namespace testing { +using ::testing::AnyOf; +using ::testing::Eq; + TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -800,6 +803,9 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { // Linux and Netstack both default to a 60s TCP_LINGER2 timeout. constexpr int kDefaultTCPLingerTimeout = 60; +// On Linux, the maximum linger2 timeout was changed from 60sec to 120sec. +constexpr int kMaxTCPLingerTimeout = 120; +constexpr int kOldMaxTCPLingerTimeout = 60; TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -813,26 +819,45 @@ TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { EXPECT_EQ(get, kDefaultTCPLingerTimeout); } -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutLessThanZero) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, - sizeof(kZero)), - SyscallSucceedsWithValue(0)); - constexpr int kNegative = -1234; EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kNegative, sizeof(kNegative)), SyscallSucceedsWithValue(0)); + int get = INT_MAX; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, -1); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); } -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout // on linux (defaults to 60 seconds on linux). - constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + constexpr int kAboveDefault = kMaxTCPLingerTimeout + 1; EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kAboveDefault, sizeof(kAboveDefault)), SyscallSucceedsWithValue(0)); @@ -843,7 +868,12 @@ TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kDefaultTCPLingerTimeout); + if (IsRunningOnGvisor()) { + EXPECT_EQ(get, kMaxTCPLingerTimeout); + } else { + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); + } } TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { @@ -1050,5 +1080,124 @@ TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { } } +// Test setsockopt and getsockopt for a socket with SO_LINGER option. +TEST_P(TCPSocketPairTest, SetAndGetLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Check getsockopt before SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_THAT(got_len, sizeof(got_linger)); + struct linger want_linger = {}; + EXPECT_EQ(0, memcmp(&want_linger, &got_linger, got_len)); + + // Set and get SO_LINGER with negative values. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = -3; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(sl.l_onoff, got_linger.l_onoff); + // Linux returns a different value as it uses HZ to convert the seconds to + // jiffies which overflows for negative values. We want to be compatible with + // linux for getsockopt return value. + if (IsRunningOnGvisor()) { + EXPECT_EQ(sl.l_linger, got_linger.l_linger); + } + + // Set and get SO_LINGER option with positive values. + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test socket to disable SO_LINGER option. +TEST_P(TCPSocketPairTest, SetOffLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + sl.l_onoff = 0; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set to zero. + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test close on dup'd socket with SO_LINGER option set. +TEST_P(TCPSocketPairTest, CloseWithLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + FileDescriptor dupFd = FileDescriptor(dup(sockets->first_fd())); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + char buf[10] = {}; + // Write on dupFd should succeed as socket will not be closed until + // all references are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); + + // Close the socket. + dupFd.reset(); + // Write on dupFd should fail as all references for socket are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index edb86aded..3f2c0fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -435,8 +435,10 @@ TEST_P(UDPSocketPairTest, TOSRecvMismatch) { // Test that an IPv4 socket does not support the IPv6 TClass option. TEST_P(UDPSocketPairTest, TClassRecvMismatch) { - // This should only test AF_INET sockets for the mismatch behavior. - SKIP_IF(GetParam().domain != AF_INET); + // This should only test AF_INET6 sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET6); + // IPV6_RECVTCLASS is only valid for SOCK_DGRAM and SOCK_RAW. + SKIP_IF(GetParam().type != SOCK_DGRAM | GetParam().type != SOCK_RAW); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -448,5 +450,27 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) { SyscallFailsWithErrno(EOPNOTSUPP)); } +// Test the SO_LINGER option can be set/get on udp socket. +TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT( + getsockopt(sockets->first_fd(), level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 1c7b0cf90..8f7ccc868 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -217,6 +217,8 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { } TEST_P(IPUnboundSocketTest, CheckSkipECN) { + // Test is inconsistant on different kernels. + SKIP_IF(!IsRunningOnGvisor()); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); int set = 0xFF; socklen_t set_sz = sizeof(set); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index de0f5f01b..a72c76c97 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -27,6 +27,7 @@ #include "absl/memory/memory.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/posix_error.h" #include "test/util/test_util.h" namespace gvisor { @@ -73,9 +74,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that not setting a default send interface prevents multicast packets @@ -207,8 +208,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -262,8 +264,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -317,8 +320,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -372,8 +376,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -431,8 +436,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -490,8 +496,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -545,8 +552,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -600,8 +608,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -659,9 +668,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that multicast works when the default send interface is configured by @@ -717,9 +726,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that multicast works when the default send interface is configured by @@ -775,8 +784,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -834,8 +844,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -907,9 +918,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that dropping a group membership prevents multicast packets from being @@ -965,9 +976,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) { @@ -1319,9 +1330,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, + sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } } @@ -1398,9 +1409,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, + sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } } @@ -1421,9 +1432,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { char recv_buf[sizeof(send_buf)] = {}; for (auto& sockets : socket_pairs) { - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, - sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, + sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } } } @@ -1474,9 +1485,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1518,9 +1529,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that a socket can bind to a multicast address and still send out @@ -1568,9 +1579,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1615,9 +1626,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1666,9 +1677,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1726,17 +1737,17 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // of the other sockets to have received it, but we will check that later. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(send_buf))); + RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(send_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } // Verify that no other messages were received. for (auto& socket : sockets) { char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } } @@ -2113,45 +2124,12 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { // balancing (REUSEPORT) instead of the most recently bound socket // (REUSEADDR). char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(kMessageSize)); - EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(kMessageSize)); -} - -// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports. -// We disable S/R because this test creates a large number of sockets. -TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) { - auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - constexpr int kClients = 65536; - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(receiver1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Disable cooperative S/R as we are making too many syscalls. - DisableSave ds; - std::vector<std::unique_ptr<FileDescriptor>> sockets; - for (int i = 0; i < kClients; i++) { - auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len); - if (ret == 0) { - sockets.push_back(std::move(s)); - continue; - } - ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); - break; - } + EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); + EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); } // Test that socket will receive packet info control message. @@ -2452,5 +2430,105 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } + +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) { + auto sender_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind the first FD to the loopback. This is an alternative to + // IP_MULTICAST_IF for setting the default send interface. + auto sender_addr = V4Loopback(); + ASSERT_THAT( + bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Bind the second FD to the v4 any address to ensure that we can receive the + // multicast packet. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Register to receive IP packet info. + const int one = 1; + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_PKTINFO, &one, + sizeof(one)), + SyscallSucceeds()); + + // Send a multicast packet. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + msghdr recv_msg = {}; + iovec recv_iov = {}; + char recv_buf[sizeof(send_buf)]; + char recv_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + recv_iov.iov_base = recv_buf; + recv_iov.iov_len = sizeof(recv_buf); + recv_msg.msg_iov = &recv_iov; + recv_msg.msg_iovlen = 1; + recv_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + recv_msg.msg_control = recv_cmsg_buf; + ASSERT_THAT(RetryEINTR(recvmsg)(receiver_socket->get(), &recv_msg, 0), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + + // Check the IP_PKTINFO control message. + cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(receiver_socket->get(), SIOCGIFINDEX, &ifr), + SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + if (IsRunningOnGvisor()) { + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, group.imr_multiaddr.s_addr); + } else { + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + } + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, group.imr_multiaddr.s_addr); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index d690d9564..b206137eb 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -42,7 +42,9 @@ TestAddress V4EmptyAddress() { } void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { - got_if_infos_ = false; + // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its + // IPv4 address on eth0. + found_net_interfaces_ = false; // Get interface list. ASSERT_NO_ERRNO(if_helper_.Load()); @@ -71,7 +73,7 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { } eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr); - got_if_infos_ = true; + found_net_interfaces_ = true; } // Verifies that a newly instantiated UDP socket does not have the @@ -110,6 +112,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { // the destination port number. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastReceivedOnExpectedPort) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -185,9 +188,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // not a unicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastReceivedOnExpectedAddresses) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -272,6 +273,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // (UDPBroadcastSendRecvOnSocketBoundToAny). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastSendRecvOnSocketBoundToBroadcast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -313,6 +315,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // (UDPBroadcastSendRecvOnSocketBoundToBroadcast). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastSendRecvOnSocketBoundToAny) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -351,6 +354,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST // disabled. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Broadcast a test message without having enabled SO_BROADCAST on the sending @@ -401,6 +405,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // multicast on gVisor. SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!found_net_interfaces_); + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -435,6 +441,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Check that multicast packets will be delivered to the sending socket without // setting an interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { + SKIP_IF(!found_net_interfaces_); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -478,6 +485,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { // set interface and IP_MULTICAST_LOOP disabled. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelfLoopOff) { + SKIP_IF(!found_net_interfaces_); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -528,6 +536,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // multicast on gVisor. SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -566,6 +576,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // Check that multicast packets will be delivered to another socket without // setting an interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -613,6 +624,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { // set interface and IP_MULTICAST_LOOP disabled on the sending socket. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSenderNoLoop) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -664,6 +676,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // setting an interface and IP_MULTICAST_LOOP disabled on the receiving socket. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastReceiverNoLoop) { + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -716,6 +730,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // and both will receive data on it when bound to the ANY address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToAny) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -782,6 +797,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // and both will receive data on it when bound to the multicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToMulticastAddress) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -851,6 +867,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // multicast address, both will receive data. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -924,6 +941,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // is not a multicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackFromAddr) { + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -991,9 +1010,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // interface, a multicast packet sent out uses the latter as its source address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackIfNicAndAddr) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); // Create receiver, bind to ANY and join the multicast group. auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1059,9 +1076,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // another interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); // FIXME (b/137790511): When bound to one interface it is not possible to set // IP_MULTICAST_IF to a different interface. diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h index 10b90b1e0..0e9e70e8e 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h @@ -29,9 +29,9 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { IfAddrHelper if_helper_; - // got_if_infos_ is set to false if SetUp() could not obtain all interface - // infos that we need. - bool got_if_infos_; + // found_net_interfaces_ is set to false if SetUp() could not obtain + // all interface infos that we need. + bool found_net_interfaces_; // Interface infos. int lo_if_idx_; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc new file mode 100644 index 000000000..8052bf404 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketNetlinkTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc new file mode 100644 index 000000000..bcbd2feac --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketNogotsanTest = SimpleSocketTest; + +// Check that connect returns EAGAIN when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketNogotsanTest, + UDPConnectPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + +// Check that bind returns EADDRINUSE when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + auto addr = V4Loopback(); + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = + bind(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketNogotsanTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc new file mode 100644 index 000000000..49a0f06d9 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -0,0 +1,225 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" + +#include <arpa/inet.h> +#include <poll.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/util/capability_util.h" +#include "test/util/cleanup.h" + +namespace gvisor { +namespace testing { + +constexpr size_t kSendBufSize = 200; + +// Checks that the loopback interface considers itself bound to all IPs in an +// associated subnet. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcv_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Send from an unassigned address but an address that is in the subnet + // associated with the loopback interface. + TestAddress sender_addr("V4NotAssignd1"); + sender_addr.addr.ss_family = AF_INET; + sender_addr.addr_len = sizeof(sockaddr_in); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2", + &(reinterpret_cast<sockaddr_in*>(&sender_addr.addr) + ->sin_addr.s_addr))); + ASSERT_THAT( + bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Send the packet to an unassigned address but an address that is in the + // subnet associated with the loopback interface. + TestAddress receiver_addr("V4NotAssigned2"); + receiver_addr.addr.ss_family = AF_INET; + receiver_addr.addr_len = sizeof(sockaddr_in); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254", + &(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr) + ->sin_addr.s_addr))); + ASSERT_THAT( + bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(rcv_sock->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len); + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + ASSERT_THAT( + RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceedsWithValue(kSendBufSize)); + + // Check that we received the packet. + char recv_buf[kSendBufSize] = {}; + ASSERT_THAT(RetryEINTR(recv)(rcv_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)); + ASSERT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)); +} + +// Tests that broadcast packets are delivered to all interested sockets +// (wildcard and broadcast address specified sockets). +// +// Note, we cannot test the IPv4 Broadcast (255.255.255.255) because we do +// not have a route to it. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { + constexpr uint16_t kPort = 9876; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + // Number of sockets per socket type. + constexpr int kNumSocketsPerType = 2; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + 24 /* prefixlen */, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + TestAddress broadcast_address("SubnetBroadcastAddress"); + broadcast_address.addr.ss_family = AF_INET; + broadcast_address.addr_len = sizeof(sockaddr_in); + auto broadcast_address_in = + reinterpret_cast<sockaddr_in*>(&broadcast_address.addr); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.255", + &broadcast_address_in->sin_addr.s_addr)); + broadcast_address_in->sin_port = htons(kPort); + + TestAddress any_address = V4Any(); + reinterpret_cast<sockaddr_in*>(&any_address.addr)->sin_port = htons(kPort); + + // We create sockets bound to both the wildcard address and the broadcast + // address to make sure both of these types of "broadcast interested" sockets + // receive broadcast packets. + std::vector<std::unique_ptr<FileDescriptor>> socks; + for (bool bind_wildcard : {false, true}) { + // Create multiple sockets for each type of "broadcast interested" + // socket so we can test that all sockets receive the broadcast packet. + for (int i = 0; i < kNumSocketsPerType; i++) { + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto idx = socks.size(); + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + if (bind_wildcard) { + ASSERT_THAT( + bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr), + any_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } else { + ASSERT_THAT(bind(sock->get(), + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } + + socks.push_back(std::move(sock)); + } + } + + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + + // Broadcasts from each socket should be received by every socket (including + // the sending socket). + for (int w = 0; w < socks.size(); w++) { + auto& w_sock = socks[w]; + ASSERT_THAT( + RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "]"; + + // Check that we received the packet on all sockets. + for (int r = 0; r < socks.size(); r++) { + auto& r_sock = socks[r]; + + struct pollfd poll_fd = {r_sock->get(), POLLIN, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)) + << "write socks[" << w << "] & read socks[" << r << "]"; + + char recv_buf[kSendBufSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(r_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + } + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h new file mode 100644 index 000000000..73e7836d5 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketNetlinkTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc new file mode 100644 index 000000000..17021ff82 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv6UDPSockets, IPv6UDPUnboundSocketNetlinkTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv6UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc new file mode 100644 index 000000000..2ee218231 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h" + +#include <arpa/inet.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/util/capability_util.h" + +namespace gvisor { +namespace testing { + +// Checks that the loopback interface does not consider itself bound to all IPs +// in an associated subnet. +TEST_P(IPv6UDPUnboundSocketNetlinkTest, JoinSubnet) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in6_addr addr; + EXPECT_EQ(1, inet_pton(AF_INET6, "2001:db8::1", &addr)); + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/64, &addr, sizeof(addr))); + + // Binding to an unassigned address but an address that is in the subnet + // associated with the loopback interface should fail. + TestAddress sender_addr("V6NotAssignd1"); + sender_addr.addr.ss_family = AF_INET6; + sender_addr.addr_len = sizeof(sockaddr_in6); + EXPECT_EQ(1, inet_pton(AF_INET6, "2001:db8::2", + reinterpret_cast<sockaddr_in6*>(&sender_addr.addr) + ->sin6_addr.s6_addr)); + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(bind(sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h new file mode 100644 index 000000000..88098be82 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv6 UDP sockets. +using IPv6UDPUnboundSocketNetlinkTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 15d4b85a7..5f8d7f981 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <linux/ethtool.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> #include <linux/sockios.h> @@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) { // Check that the loopback is zero hardware address. ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0); @@ -178,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) { EXPECT_GT(ifr.ifr_mtu, 0); } +TEST(NetdeviceTest, EthtoolGetTSInfo) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ethtool_ts_info tsi = {}; + tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities. + + // Prepare the request. + struct ifreq ifr = {}; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + ifr.ifr_data = (void*)&tsi; + + // Check that SIOCGIFMTU returns a nonzero MTU. + if (IsRunningOnGvisor()) { + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), + SyscallFailsWithErrno(EOPNOTSUPP)); + return; + } + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index e6647a1c3..b3fcf8e7c 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -577,7 +577,10 @@ TEST(NetlinkRouteTest, GetRouteDump) { std::cout << std::endl; - if (msg->rtm_table == RT_TABLE_MAIN) { + // If the test is running in a new network namespace, it will have only + // the local route table. + if (msg->rtm_table == RT_TABLE_MAIN || + (!IsRunningOnGvisor() && msg->rtm_table == RT_TABLE_LOCAL)) { routeFound = true; dstFound = rtDstFound && dstFound; } diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index bde1dbb4d..7a0bad4cb 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -26,6 +26,62 @@ namespace { constexpr uint32_t kSeq = 12345; +// Types of address modifications that may be performed on an interface. +enum class LinkAddrModification { + kAdd, + kDelete, +}; + +// Populates |hdr| with appripriate values for the modification type. +PosixError PopulateNlmsghdr(LinkAddrModification modification, + struct nlmsghdr* hdr) { + switch (modification) { + case LinkAddrModification::kAdd: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + return NoError(); + case LinkAddrModification::kDelete: + hdr->nlmsg_type = RTM_DELADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + return NoError(); + } + + return PosixError(EINVAL); +} + +// Adds or removes the specified address from the specified interface. +PosixError LinkModifyLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen, + LinkAddrModification modification) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + PosixError err = PopulateNlmsghdr(modification, &req.hdr); + if (!err.ok()) { + return err; + } + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + } // namespace PosixError DumpLinks( @@ -84,31 +140,14 @@ PosixErrorOr<Link> LoopbackLink() { PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifaddr; - char attrbuf[512]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifaddr.ifa_index = index; - req.ifaddr.ifa_family = family; - req.ifaddr.ifa_prefixlen = prefixlen; - - struct rtattr* rta = reinterpret_cast<struct rtattr*>( - reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); - rta->rta_type = IFA_LOCAL; - rta->rta_len = RTA_LENGTH(addrlen); - req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); - memcpy(RTA_DATA(rta), addr, addrlen); + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kAdd); +} - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kDelete); } PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index 149c4a7f6..e5badca70 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -43,6 +43,10 @@ PosixErrorOr<Link> LoopbackLink(); PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); +// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface. +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + // LinkChangeFlags changes interface flags. E.g. IFF_UP. PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 53b678e94..e11792309 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -753,6 +753,20 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { return ret; } +PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, + int timeout) { + fd_set rfd; + struct timeval to = {.tv_sec = timeout, .tv_usec = 0}; + FD_ZERO(&rfd); + FD_SET(sock, &rfd); + + int ret; + RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to)); + RETURN_ERROR_IF_SYSCALL_FAIL( + ret = RetryEINTR(recv)(sock, buf, buf_size, MSG_DONTWAIT)); + return ret; +} + void RecvNoData(int sock) { char data = 0; struct iovec iov; diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 734b48b96..468bc96e0 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -467,6 +467,10 @@ PosixError FreeAvailablePort(int port); // SendMsg converts a buffer to an iovec and adds it to msg before sending it. PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size); +// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock. +PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, + int timeout); + // RecvNoData checks that no data is receivable on sock. void RecvNoData(int sock); diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 99e77b89e..1edcb15a7 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -103,6 +103,24 @@ TEST_P(StreamUnixSocketPairTest, Sendto) { SyscallFailsWithErrno(EISCONN)); } +TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct linger sl = {1, 5}; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&got_linger, &sl, length)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index 08fc4b1b7..a1d2b9b11 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -298,6 +298,23 @@ TEST(SpliceTest, ToPipe) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, ToPipeEOF) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Splice from the empty file to the pipe. + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, 123, 0), + SyscallSucceedsWithValue(0)); +} + TEST(SpliceTest, ToPipeOffset) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -342,7 +359,7 @@ TEST(SpliceTest, FromPipe) { ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); - // Open the input file. + // Open the output file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); @@ -364,6 +381,40 @@ TEST(SpliceTest, FromPipe) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, FromPipeMultiple) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + std::string buf = "abcABC123"; + ASSERT_THAT(write(wfd.get(), buf.c_str(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + + // Open the output file. + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor out_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); + + // Splice from the pipe to the output file over several calls. + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + + // Reset cursor to zero so that we can check the contents. + ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + + // Contents should be equal. + std::vector<char> rbuf(buf.size()); + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), buf.c_str(), buf.size()), 0); +} + TEST(SpliceTest, FromPipeOffset) { // Create a new pipe. int fds[2]; @@ -693,6 +744,34 @@ TEST(SpliceTest, FromPipeMaxFileSize) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, FromPipeToDevZero) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + const FileDescriptor zero = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); + + // Close the write end to prevent blocking below. + wfd.reset(); + + // Splice to /dev/zero. The first call should empty the pipe, and the return + // value should not exceed the number of bytes available for reading. + EXPECT_THAT( + splice(rfd.get(), nullptr, zero.get(), nullptr, kPageSize + 123, 0), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_THAT(splice(rfd.get(), nullptr, zero.get(), nullptr, 1, 0), + SyscallSucceedsWithValue(0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 2503960f3..92260b1e1 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -97,6 +97,11 @@ TEST_F(StatTest, FstatatSymlink) { } TEST_F(StatTest, Nlinks) { + // Skip this test if we are testing overlayfs because overlayfs does not + // (intentionally) return the correct nlink value for directories. + // See fs/overlayfs/inode.c:ovl_getattr(). + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Directory is initially empty, it should contain 2 links (one from itself, @@ -328,20 +333,23 @@ TEST_F(StatTest, LeadingDoubleSlash) { // Test that a rename doesn't change the underlying file. TEST_F(StatTest, StatDoesntChangeAfterRename) { - const TempPath old_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath old_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const TempPath new_path(NewTempAbsPath()); struct stat st_old = {}; struct stat st_new = {}; - ASSERT_THAT(stat(old_dir.path().c_str(), &st_old), SyscallSucceeds()); - ASSERT_THAT(rename(old_dir.path().c_str(), new_path.path().c_str()), + ASSERT_THAT(stat(old_file.path().c_str(), &st_old), SyscallSucceeds()); + ASSERT_THAT(rename(old_file.path().c_str(), new_path.path().c_str()), SyscallSucceeds()); ASSERT_THAT(stat(new_path.path().c_str(), &st_new), SyscallSucceeds()); EXPECT_EQ(st_old.st_nlink, st_new.st_nlink); EXPECT_EQ(st_old.st_dev, st_new.st_dev); - EXPECT_EQ(st_old.st_ino, st_new.st_ino); + // Overlay filesystems may synthesize directory inode numbers on the fly. + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { + EXPECT_EQ(st_old.st_ino, st_new.st_ino); + } EXPECT_EQ(st_old.st_mode, st_new.st_mode); EXPECT_EQ(st_old.st_uid, st_new.st_uid); EXPECT_EQ(st_old.st_gid, st_new.st_gid); @@ -378,7 +386,9 @@ TEST_F(StatTest, LinkCountsWithRegularFileChild) { // This test verifies that inodes remain around when there is an open fd // after link count hits 0. -TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { +// +// It is marked NoSave because we don't support saving unlinked files. +TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoSave) { // Setting the enviornment variable GVISOR_GOFER_UNCACHED to any value // will prevent this test from running, see the tmpfs lifecycle. // @@ -387,9 +397,6 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { const char* uncached_gofer = getenv("GVISOR_GOFER_UNCACHED"); SKIP_IF(uncached_gofer != nullptr); - // We don't support saving unlinked files. - const DisableSave ds; - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( dir.path(), "hello", TempPath::kDefaultFileMode)); @@ -432,6 +439,11 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { // Test link counts with a directory as the child. TEST_F(StatTest, LinkCountsWithDirChild) { + // Skip this test if we are testing overlayfs because overlayfs does not + // (intentionally) return the correct nlink value for directories. + // See fs/overlayfs/inode.c:ovl_getattr(). + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Before a child is added the two links are "." and the link from the parent. diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc index aca51d30f..f0fb166bd 100644 --- a/test/syscalls/linux/statfs.cc +++ b/test/syscalls/linux/statfs.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/magic.h> #include <sys/statfs.h> #include <unistd.h> @@ -43,14 +44,10 @@ TEST(StatfsTest, InternalTmpfs) { TEST(StatfsTest, InternalDevShm) { struct statfs st; EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); -} - -TEST(StatfsTest, NameLen) { - struct statfs st; - EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); // This assumes that /dev/shm is tmpfs. - EXPECT_EQ(st.f_namelen, NAME_MAX); + // Note: We could be an overlay on some configurations. + EXPECT_TRUE(st.f_type == TMPFS_MAGIC || st.f_type == OVERLAYFS_SUPER_MAGIC); } TEST(FstatfsTest, CannotStatBadFd) { diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index a17ff62e9..4d9eba7f0 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -218,6 +218,36 @@ TEST(SymlinkTest, PreadFromSymlink) { EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); } +TEST(SymlinkTest, PwriteToSymlink) { + std::string name = NewTempAbsPath(); + int fd; + ASSERT_THAT(fd = open(name.c_str(), O_CREAT, 0644), SyscallSucceeds()); + ASSERT_THAT(close(fd), SyscallSucceeds()); + + std::string linkname = NewTempAbsPath(); + ASSERT_THAT(symlink(name.c_str(), linkname.c_str()), SyscallSucceeds()); + + ASSERT_THAT(fd = open(linkname.c_str(), O_WRONLY), SyscallSucceeds()); + + const int data_size = 10; + const std::string data = std::string(data_size, 'a'); + EXPECT_THAT(pwrite64(fd, data.c_str(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); + + ASSERT_THAT(close(fd), SyscallSucceeds()); + ASSERT_THAT(fd = open(name.c_str(), O_RDONLY), SyscallSucceeds()); + + char buf[data_size + 1]; + EXPECT_THAT(pread64(fd, buf, data.size(), 0), SyscallSucceeds()); + buf[data.size()] = '\0'; + EXPECT_STREQ(buf, data.c_str()); + + ASSERT_THAT(close(fd), SyscallSucceeds()); + + EXPECT_THAT(unlink(name.c_str()), SyscallSucceeds()); + EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); +} + TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { // Drop capabilities that allow us to override file and directory permissions. ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); @@ -297,6 +327,16 @@ TEST(SymlinkTest, FollowUpdatesATime) { EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime); } +TEST(SymlinkTest, SymlinkAtEmptyPath) { + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(symlinkat(file.path().c_str(), fd.get(), ""), + SyscallFailsWithErrno(ENOENT)); +} + class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {}; // Test that creating an existing symlink with creat will create the target. diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index a4d2953e1..e0981e28a 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -13,6 +13,9 @@ // limitations under the License. #include <fcntl.h> +#ifdef __linux__ +#include <linux/filter.h> +#endif // __linux__ #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -717,6 +720,30 @@ TEST_P(TcpSocketTest, TcpSCMPriority) { ASSERT_EQ(cmsg, nullptr); } +TEST_P(TcpSocketTest, TimeWaitPollHUP) { + shutdown(s_, SHUT_RDWR); + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = s_, + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); + shutdown(t_, SHUT_RDWR); + t.Join(); + // At this point s_ should be in TIME-WAIT and polling for POLLHUP should + // return with 1 FD. + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = s_, + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); @@ -1559,6 +1586,93 @@ TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { } } +#ifdef __linux__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // Program generated using sudo tcpdump -i lo tcp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000006}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000006}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +#endif // __linux__ + +TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + constexpr int kLingerTimeout = 10; // Seconds. + + // Set the SO_LINGER option. + struct linger sl = { + .l_onoff = 1, + .l_linger = kLingerTimeout, + }; + ASSERT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + struct pollfd poll_fd = { + .fd = s.get(), + .events = POLLHUP, + }; + constexpr int kPollTimeoutMs = 0; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + auto const start_time = absl::Now(); + EXPECT_THAT(close(s.release()), SyscallSucceeds()); + auto const end_time = absl::Now(); + + // Close() should not linger and return immediately. + ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index c988c6380..bfc95ed38 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -196,6 +196,26 @@ TEST(TruncateTest, FtruncateNonWriteable) { EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +// ftruncate(2) should succeed as long as the file descriptor is writeable, +// regardless of whether the file permissions allow writing. +TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) { + // Drop capabilities that allow us to override file permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + auto path = NewTempAbsPath(); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT, 0444)); + + // In goferfs, ftruncate may be converted to a remote truncate operation that + // unavoidably requires write permission. + SKIP_IF(IsRunningOnGvisor() && !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(path))); + ASSERT_THAT(ftruncate(fd.get(), 100), SyscallSucceeds()); +} + TEST(TruncateTest, TruncateNonExist) { EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT)); } diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 7a8ac30a4..1a7673317 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,13 +12,1845 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/udp_socket_test_cases.h" +#include <arpa/inet.h> +#include <fcntl.h> + +#include <ctime> + +#ifdef __linux__ +#include <linux/errqueue.h> +#include <linux/filter.h> +#endif // __linux__ +#include <netinet/in.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "absl/strings/str_format.h" +#ifndef SIOCGSTAMP +#include <linux/sockios.h> +#endif + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { namespace { +// Fixture for tests parameterized by the address family to use (AF_INET and +// AF_INET6) when creating sockets. +class UdpSocketTest + : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { + protected: + // Creates two sockets that will be used by test cases. + void SetUp() override; + + // Binds the socket bind_ to the loopback and updates bind_addr_. + PosixError BindLoopback(); + + // Binds the socket bind_ to Any and updates bind_addr_. + PosixError BindAny(); + + // Binds given socket to address addr and updates. + PosixError BindSocket(int socket, struct sockaddr* addr); + + // Return initialized Any address to port 0. + struct sockaddr_storage InetAnyAddr(); + + // Return initialized Loopback address to port 0. + struct sockaddr_storage InetLoopbackAddr(); + + // Disconnects socket sockfd. + void Disconnect(int sockfd); + + // Get family for the test. + int GetFamily(); + + // Socket used by Bind methods + FileDescriptor bind_; + + // Second socket used for tests. + FileDescriptor sock_; + + // Address for bind_ socket. + struct sockaddr* bind_addr_; + + // Initialized to the length based on GetFamily(). + socklen_t addrlen_; + + // Storage for bind_addr_. + struct sockaddr_storage bind_addr_storage_; + + private: + // Helper to initialize addrlen_ for the test case. + socklen_t GetAddrLength(); +}; + +// Gets a pointer to the port component of the given address. +uint16_t* Port(struct sockaddr_storage* addr) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + return &sin->sin_port; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + return &sin6->sin6_port; + } + } + + return nullptr; +} + +// Sets addr port to "port". +void SetPort(struct sockaddr_storage* addr, uint16_t port) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + sin->sin_port = port; + break; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + sin6->sin6_port = port; + break; + } + } +} + +void UdpSocketTest::SetUp() { + addrlen_ = GetAddrLength(); + + bind_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); + bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + + sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); +} + +int UdpSocketTest::GetFamily() { + if (GetParam() == AddressFamily::kIpv4) { + return AF_INET; + } + return AF_INET6; +} + +PosixError UdpSocketTest::BindLoopback() { + bind_addr_storage_ = InetLoopbackAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindAny() { + bind_addr_storage_ = InetAnyAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { + socklen_t len = sizeof(bind_addr_storage_); + + // Bind, then check that we get the right address. + RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); + + RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); + + if (addrlen_ != len) { + return PosixError( + EINVAL, + absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); + } + return PosixError(0); +} + +socklen_t UdpSocketTest::GetAddrLength() { + struct sockaddr_storage addr; + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + return sizeof(*sin); + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + return sizeof(*sin6); +} + +sockaddr_storage UdpSocketTest::InetAnyAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + sin->sin_port = htons(0); + return addr; + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = IN6ADDR_ANY_INIT; + sin6->sin6_port = htons(0); + return addr; +} + +sockaddr_storage UdpSocketTest::InetLoopbackAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + sin->sin_port = htons(0); + return addr; + } + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = in6addr_loopback; + sin6->sin6_port = htons(0); + return addr; +} + +void UdpSocketTest::Disconnect(int sockfd) { + sockaddr_storage addr_storage = InetAnyAddr(); + sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + socklen_t addrlen = sizeof(addr_storage); + + addr->sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, Creation) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); +} + +TEST_P(UdpSocketTest, Getsockname) { + // Check that we're not bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), + 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Getpeername) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that we're not connected. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Connect, then check that we get the right address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, SendNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Do send & write, they must fail. + char buf[512]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(EDESTADDRREQ)); + + EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EDESTADDRREQ)); + + // Use sendto. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ConnectBinds) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ReceiveNotBound) { + char buf[512]; + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, Bind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EINVAL)); + + // Check that we're still bound to the original address. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, BindInUse) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(UdpSocketTest, ReceiveAfterConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send from sock_ to bind_ + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + for (int i = 0; i < 2; i++) { + // Connet sock_ to bound address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + // Send from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect sock_. + struct sockaddr unspec = {}; + unspec.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), + SyscallSucceeds()); + } +} + +TEST_P(UdpSocketTest, Connect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to bind after connect. + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallFailsWithErrno(EINVAL)); + + struct sockaddr_storage bind2_storage = InetLoopbackAddr(); + struct sockaddr* bind2_addr = + reinterpret_cast<struct sockaddr*>(&bind2_storage); + FileDescriptor bind2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); + + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); + + // Check that peer name changed. + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); +} + +TEST_P(UdpSocketTest, ConnectAnyZero) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind to the next port above bind_. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage unspec = {}; + unspec.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), + sizeof(unspec.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(unspec); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { + ASSERT_NO_ERRNO(BindAny()); + + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + socklen_t addrlen = sizeof(addr); + + // Connect the socket. + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); + + // If the socket is bound to ANY and connected to a loopback address, + // getsockname() has to return the loopback address. + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + Disconnect(sock_.get()); + + // Check that we're still bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, any, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), + sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), *Port(&any_storage)); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = GetFamily(); + ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage addr_storage = InetAnyAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send to a different destination than we're connected to. + char buf[512]; + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + // Connect to loopback:bind_addr_+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Set sock to non-blocking. + int opts = 0; + ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallFailsWithErrno(EAGAIN)); +} + +TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, SendAndReceiveConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveFromNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that the data isn't received because it was sent from a different + // address than we're connected. + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveBeforeConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Receive the data. It works because it was sent before the connect. + char received[sizeof(buf)]; + EXPECT_THAT( + RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Send again. This time it should not be received. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveFrom) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data and sender address. + char received[sizeof(buf)]; + struct sockaddr_storage addr2; + socklen_t addr2len = sizeof(addr2); + EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, + reinterpret_cast<sockaddr*>(&addr2), &addr2len), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + EXPECT_EQ(addr2len, addrlen_); + EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Listen) { + ASSERT_THAT(listen(sock_.get(), SOMAXCONN), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +TEST_P(UdpSocketTest, Accept) { + ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +// This test validates that a read shutdown with pending data allows the read +// to proceed with the data before returning EAGAIN. +TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Verify that we get EWOULDBLOCK when there is nothing to read. + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + const char* buf = "abc"; + EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); + + int opts = 0; + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_NE(opts & O_NONBLOCK, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // We should get the data even though read has been shutdown. + EXPECT_THAT( + RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/), + IsPosixErrorOkAndHolds(2)); + + // Because we read less than the entire packet length, since it's a packet + // based socket any subsequent reads should return EWOULDBLOCK. + EXPECT_THAT(recv(bind_.get(), received, 1, 0), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +// This test is validating that even after a socket is shutdown if it's +// reconnected it will reset the shutdown state. +TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReadShutdown) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then shutdown from another thread. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + }); + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); + t.Join(); + + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, WriteShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SynchronousReceive) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_ from another thread. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + // Receive the data prior to actually starting the other thread. + char received[512]; + EXPECT_THAT( + RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Start the thread. + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, + this->addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + }); + + EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(512)); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + } + + // Receive the data as 3 separate packets. + char received[6 * psize]; + for (int i = 0; i < 3; ++i) { + EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), + SyscallSucceedsWithValue(psize)); + } + EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Direct writes from sock to bind_. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(writev(sock_.get(), iov, 2), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(readv(bind_.get(), iov, 3), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_name = bind_addr_; + msg.msg_namelen = addrlen_; + msg.msg_iov = iov; + msg.msg_iovlen = 2; + ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 3; + ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, FIONREADShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + + int n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, FIONREADWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), + SyscallSucceedsWithValue(sizeof(str))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); +} + +// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the +// corresponding macro and become `0x541B`. +TEST_P(UdpSocketTest, Fionread) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, psize); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(0)); + + // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet + // socket does not cause a poll event to be triggered. + if (!IsRunningWithHostinet()) { + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + } + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoNoCheck) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOn); + ASSERT_EQ(optlen, sizeof(v)); + + v = kSockOptOff; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +#ifdef __linux__ +TEST_P(UdpSocketTest, ErrorQueue) { + char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), + SyscallFailsWithErrno(EAGAIN)); +} +#endif // __linux__ + +TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoTimestamp) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + int v = 1; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); + + struct timeval tv = {}; + memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); + + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // There should be nothing to get via ioctl. + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, WriteShutdownNotConnected) { + EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, TimestampIoctl) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); +} + +TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct timeval tv = {}; + ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +// Test that the timestamp accessed via SIOCGSTAMP is still accessible after +// SO_TIMESTAMP is enabled and used to retrieve a timestamp. +TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // Enable SO_TIMESTAMP and send a message. + int v = 1; + EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be a message for SO_TIMESTAMP. + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg = {}; + iovec iov = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + + // The ioctl should return the exact same values as before. + struct timeval tv2 = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_EQ(tv.tv_sec, tv2.tv_sec); + ASSERT_EQ(tv.tv_usec, tv2.tv_usec); +} + +// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on +// outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SetAndReceiveTOS) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Set socket TOS. + int sent_level = recv_level; + int sent_type = IP_TOS; + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + } + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, + sizeof(sent_tos)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_type == IPV6_TCLASS) { + cmsg_data_len = sizeof(int); + } + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = received_cmsgbuf.size(); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the +// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SendAndReceiveTOS) { + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + int recv_opt = kSockOptOn; + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, + sizeof(recv_opt)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + int sent_level = recv_level; + int sent_type = IP_TOS; + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + cmsg_data_len = sizeof(int); + } + std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + sent_msg.msg_control = &sent_cmsgbuf[0]; + sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + + // Manually add control message. + struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); + sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); + sent_cmsg->cmsg_level = sent_level; + sent_cmsg->cmsg_type = sent_type; + *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Bind bind_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + { + // Send data of size min and verify that it's received. + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + std::vector<char> received(buf.size()); + EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(received.size())); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector<char> buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + + std::vector<char> received(buf.size()); + ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(received.size())); + } +} + +// Test that receive buffer limits are enforced. +TEST_P(UdpSocketTest, RecvBufLimits) { + // Bind s_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 4. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { + // Linux doubles the value specified so just set to min * 2. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, + &rcv_buf_len), + SyscallSucceeds()); + } + + { + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + int sent = 4; + if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + sent++; + } + + for (int i = 0; i < sent - 1; i++) { + // Receive the data. + std::vector<char> received(buf.size()); + EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(received.size())); + EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); + } + + // The last receive should fail with EAGAIN as the last packet should have + // been dropped due to lack of space in the receive buffer. + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +#ifdef __linux__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(UdpSocketTest, SetSocketDetachFilter) { + // Program generated using sudo tcpdump -i lo udp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +#endif // __linux__ + +TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT( + getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc deleted file mode 100644 index 54a0594f7..000000000 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __fuchsia__ - -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/errqueue.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/udp_socket_test_cases.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(UdpSocketTest, ErrorQueue) { - char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), - SyscallFailsWithErrno(EAGAIN)); -} - -} // namespace testing -} // namespace gvisor - -#endif // __fuchsia__ diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc deleted file mode 100644 index 9cc6be4fb..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ /dev/null @@ -1,1727 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/udp_socket_test_cases.h" - -#include <arpa/inet.h> -#include <fcntl.h> -#include <netinet/in.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "absl/strings/str_format.h" -#ifndef SIOCGSTAMP -#include <linux/sockios.h> -#endif - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -// Gets a pointer to the port component of the given address. -uint16_t* Port(struct sockaddr_storage* addr) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - return &sin->sin_port; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - return &sin6->sin6_port; - } - } - - return nullptr; -} - -// Sets addr port to "port". -void SetPort(struct sockaddr_storage* addr, uint16_t port) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - sin->sin_port = port; - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - sin6->sin6_port = port; - break; - } - } -} - -void UdpSocketTest::SetUp() { - addrlen_ = GetAddrLength(); - - bind_ = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); - bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - - sock_ = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); -} - -int UdpSocketTest::GetFamily() { - if (GetParam() == AddressFamily::kIpv4) { - return AF_INET; - } - return AF_INET6; -} - -PosixError UdpSocketTest::BindLoopback() { - bind_addr_storage_ = InetLoopbackAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - return BindSocket(bind_.get(), bind_addr_); -} - -PosixError UdpSocketTest::BindAny() { - bind_addr_storage_ = InetAnyAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - return BindSocket(bind_.get(), bind_addr_); -} - -PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { - socklen_t len = sizeof(bind_addr_storage_); - - // Bind, then check that we get the right address. - RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); - - RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); - - if (addrlen_ != len) { - return PosixError( - EINVAL, - absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); - } - return PosixError(0); -} - -socklen_t UdpSocketTest::GetAddrLength() { - struct sockaddr_storage addr; - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - return sizeof(*sin); - } - - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - return sizeof(*sin6); -} - -sockaddr_storage UdpSocketTest::InetAnyAddr() { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); - - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - sin->sin_port = htons(0); - return addr; - } - - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - sin6->sin6_addr = IN6ADDR_ANY_INIT; - sin6->sin6_port = htons(0); - return addr; -} - -sockaddr_storage UdpSocketTest::InetLoopbackAddr() { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); - - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(0); - return addr; - } - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(0); - return addr; -} - -void UdpSocketTest::Disconnect(int sockfd) { - sockaddr_storage addr_storage = InetAnyAddr(); - sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - socklen_t addrlen = sizeof(addr_storage); - - addr->sa_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); - - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, Creation) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - EXPECT_THAT(close(sock.release()), SyscallSucceeds()); - - sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); - EXPECT_THAT(close(sock.release()), SyscallSucceeds()); - - ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); -} - -TEST_P(UdpSocketTest, Getsockname) { - // Check that we're not bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), - 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, Getpeername) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that we're not connected. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Connect, then check that we get the right address. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, SendNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Do send & write, they must fail. - char buf[512]; - EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); - - EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), - SyscallFailsWithErrno(EDESTADDRREQ)); - - // Use sendto. - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ConnectBinds) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ReceiveNotBound) { - char buf[512]; - EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, Bind) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Try to bind again. - EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), - SyscallFailsWithErrno(EINVAL)); - - // Check that we're still bound to the original address. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, BindInUse) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Try to bind again. - EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UdpSocketTest, ReceiveAfterConnect) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send from sock_ to bind_ - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - for (int i = 0; i < 2; i++) { - // Connet sock_ to bound address. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - // Send from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Disconnect sock_. - struct sockaddr unspec = {}; - unspec.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), - SyscallSucceeds()); - } -} - -TEST_P(UdpSocketTest, Connect) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); - - // Try to bind after connect. - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallFailsWithErrno(EINVAL)); - - struct sockaddr_storage bind2_storage = InetLoopbackAddr(); - struct sockaddr* bind2_addr = - reinterpret_cast<struct sockaddr*>(&bind2_storage); - FileDescriptor bind2 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); - - // Try to connect again. - EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); - - // Check that peer name changed. - peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); -} - -TEST_P(UdpSocketTest, ConnectAnyZero) { - // TODO(138658473): Enable when we can connect to port 0 with gVisor. - SKIP_IF(IsRunningOnGvisor()); - - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, ConnectAnyWithPort) { - ASSERT_NO_ERRNO(BindAny()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - // TODO(138658473): Enable when we can connect to port 0 with gVisor. - SKIP_IF(IsRunningOnGvisor()); - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - Disconnect(sock_.get()); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - ASSERT_NO_ERRNO(BindAny()); - EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); - - Disconnect(sock_.get()); -} - -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Bind to the next port above bind_. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage unspec = {}; - unspec.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), - sizeof(unspec.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - socklen_t addrlen = sizeof(unspec); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { - ASSERT_NO_ERRNO(BindAny()); - - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - socklen_t addrlen = sizeof(addr); - - // Connect the socket. - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); - - // If the socket is bound to ANY and connected to a loopback address, - // getsockname() has to return the loopback address. - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); - struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); - SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); - - ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - Disconnect(sock_.get()); - - // Check that we're still bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, any, addrlen), 0); - - addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, Disconnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); - SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); - - for (int i = 0; i < 2; i++) { - // Try to connect again. - EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); - - // Try to disconnect. - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), - sizeof(addr.ss_family)), - SyscallSucceeds()); - - peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), *Port(&any_storage)); - } -} - -TEST_P(UdpSocketTest, ConnectBadAddress) { - struct sockaddr addr = {}; - addr.sa_family = GetFamily(); - ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage addr_storage = InetAnyAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send to a different destination than we're connected to. - char buf[512]; - EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - // Connect to loopback:bind_addr_+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from bind_ to sock_. - ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {sock_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), - SyscallSucceedsWithValue(1)); - - // Receive the packet. - char received[3]; - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_port+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Set sock to non-blocking. - int opts = 0; - ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from bind_ to sock_. - ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {sock_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // Receive the packet. - char received[3]; - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send some data to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, SendAndReceiveConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:TestPort+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_port+2. - struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); - SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); - ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that the data isn't received because it was sent from a different - // address than we're connected. - EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Bind sock to loopback:bind_addr_port+2. - struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); - SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); - ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Connect to loopback:TestPort+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Receive the data. It works because it was sent before the connect. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Send again. This time it should not be received. - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveFrom) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:TestPort+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data and sender address. - char received[sizeof(buf)]; - struct sockaddr_storage addr2; - socklen_t addr2len = sizeof(addr2); - EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr2), &addr2len), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addr2len, addrlen_); - EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); -} - -TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(sock_.get(), SOMAXCONN), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test validates that a read shutdown with pending data allows the read -// to proceed with the data before returning EAGAIN. -TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Verify that we get EWOULDBLOCK when there is nothing to read. - char received[512]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - const char* buf = "abc"; - EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); - - int opts = 0; - ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_NE(opts & O_NONBLOCK, 0); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2)); - - // Because we read less than the entire packet length, since it's a packet - // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(bind_.get(), received, 1, 0), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// This test is validating that even after a socket is shutdown if it's -// reconnected it will reset the shutdown state. -TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { - char received[512]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReadShutdown) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - - char received[512]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - ASSERT_NO_ERRNO(BindLoopback()); - - char received[512]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - }); - EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); - t.Join(); - - EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, WriteShutdown) { - ASSERT_NO_ERRNO(BindLoopback()); - EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SynchronousReceive) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send some data to bind_ from another thread. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - // Receive the data prior to actually starting the other thread. - char received[512]; - EXPECT_THAT( - RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Start the thread. - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, - this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - }); - - EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(512)); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(psize)); - } - - // Receive the data as 3 separate packets. - char received[6 * psize]; - for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), - SyscallSucceedsWithValue(psize)); - } - EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Direct writes from sock to bind_. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send 2 packets from sock to bind_, where each packet's data consists of - // 2 discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(writev(sock_.get(), iov, 2), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(readv(bind_.get(), iov, 3), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send 2 packets from sock to bind_, where each packet's data consists of - // 2 discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_name = bind_addr_; - msg.msg_namelen = addrlen_; - msg.msg_iov = iov; - msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_iov = iov; - msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, FIONREADShutdown) { - ASSERT_NO_ERRNO(BindLoopback()); - - int n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, FIONREADWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), - SyscallSucceedsWithValue(sizeof(str))); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); -} - -// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the -// corresponding macro and become `0x541B`. -TEST_P(UdpSocketTest, Fionread) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(psize)); - - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, psize); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(0)); - - // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet - // socket does not cause a poll event to be triggered. - if (!IsRunningWithHostinet()) { - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - } - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoNoCheck) { - // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = kSockOptOn; - socklen_t optlen = sizeof(v); - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), - SyscallSucceeds()); - v = -1; - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOn); - ASSERT_EQ(optlen, sizeof(v)); - - v = kSockOptOff; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), - SyscallSucceeds()); - v = -1; - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestamp) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - int v = 1; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(0)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - struct timeval tv = {}; - memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); - - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, TimestampIoctl) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); -} - -TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct timeval tv = {}; - ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); -} - -// Test that the timestamp accessed via SIOCGSTAMP is still accessible after -// SO_TIMESTAMP is enabled and used to retrieve a timestamp. -TEST_P(UdpSocketTest, TimestampIoctlPersistence) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // Enable SO_TIMESTAMP and send a message. - int v = 1; - EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be a message for SO_TIMESTAMP. - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg = {}; - iovec iov = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(0)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - - // The ioctl should return the exact same values as before. - struct timeval tv2 = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); - ASSERT_EQ(tv.tv_sec, tv2.tv_sec); - ASSERT_EQ(tv.tv_usec, tv2.tv_usec); -} - -// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on -// outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SetAndReceiveTOS) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Set socket TOS. - int sent_level = recv_level; - int sent_type = IP_TOS; - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - } - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, - sizeof(sent_tos)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_type == IPV6_TCLASS) { - cmsg_data_len = sizeof(int); - } - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = received_cmsgbuf.size(); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the -// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - int recv_opt = kSockOptOn; - ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, - sizeof(recv_opt)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - int sent_level = recv_level; - int sent_type = IP_TOS; - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - cmsg_data_len = sizeof(int); - } - std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - sent_msg.msg_control = &sent_cmsgbuf[0]; - sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - - // Manually add control message. - struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); - sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); - sent_cmsg->cmsg_level = sent_level; - sent_cmsg->cmsg_type = sent_type; - *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; - - ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { - // Discover minimum buffer size by setting it to zero. - constexpr int kRcvBufSz = 0; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, - sizeof(kRcvBufSz)), - SyscallSucceeds()); - - int min = 0; - socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), - SyscallSucceeds()); - - // Bind bind_ to loopback. - ASSERT_NO_ERRNO(BindLoopback()); - - { - // Send data of size min and verify that it's received. - std::vector<char> buf(min); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - } - - { - // Send data of size min + 1 and verify that its received. Both linux and - // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer - // is currently empty. - std::vector<char> buf(min + 1); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - } -} - -// Test that receive buffer limits are enforced. -TEST_P(UdpSocketTest, RecvBufLimits) { - // Bind s_ to loopback. - ASSERT_NO_ERRNO(BindLoopback()); - - int min = 0; - { - // Discover minimum buffer size by trying to set it to zero. - constexpr int kRcvBufSz = 0; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, - sizeof(kRcvBufSz)), - SyscallSucceeds()); - - socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), - SyscallSucceeds()); - } - - // Now set the limit to min * 4. - int new_rcv_buf_sz = min * 4; - if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { - // Linux doubles the value specified so just set to min * 2. - new_rcv_buf_sz = min * 2; - } - - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, - sizeof(new_rcv_buf_sz)), - SyscallSucceeds()); - int rcv_buf_sz = 0; - { - socklen_t rcv_buf_len = sizeof(rcv_buf_sz); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, - &rcv_buf_len), - SyscallSucceeds()); - } - - { - std::vector<char> buf(min); - RandomizeBuffer(buf.data(), buf.size()); - - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - int sent = 4; - if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { - // Linux seems to drop the 4th packet even though technically it should - // fit in the receive buffer. - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - sent++; - } - - for (int i = 0; i < sent - 1; i++) { - // Receive the data. - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); - } - - // The last receive should fail with EAGAIN as the last packet should have - // been dropped due to lack of space in the receive buffer. - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - } -} - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h deleted file mode 100644 index f7e25c805..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ -#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ - -#include <sys/socket.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// The initial port to be be used on gvisor. -constexpr int TestPort = 40000; - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class UdpSocketTest - : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { - protected: - // Creates two sockets that will be used by test cases. - void SetUp() override; - - // Binds the socket bind_ to the loopback and updates bind_addr_. - PosixError BindLoopback(); - - // Binds the socket bind_ to Any and updates bind_addr_. - PosixError BindAny(); - - // Binds given socket to address addr and updates. - PosixError BindSocket(int socket, struct sockaddr* addr); - - // Return initialized Any address to port 0. - struct sockaddr_storage InetAnyAddr(); - - // Return initialized Loopback address to port 0. - struct sockaddr_storage InetLoopbackAddr(); - - // Disconnects socket sockfd. - void Disconnect(int sockfd); - - // Get family for the test. - int GetFamily(); - - // Socket used by Bind methods - FileDescriptor bind_; - - // Second socket used for tests. - FileDescriptor sock_; - - // Address for bind_ socket. - struct sockaddr* bind_addr_; - - // Initialized to the length based on GetFamily(). - socklen_t addrlen_; - - // Storage for bind_addr_. - struct sockaddr_storage bind_addr_storage_; - - private: - // Helper to initialize addrlen_ for the test case. - socklen_t GetAddrLength(); -}; -} // namespace testing -} // namespace gvisor - -#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc index 2040375c9..061e2e0f1 100644 --- a/test/syscalls/linux/unlink.cc +++ b/test/syscalls/linux/unlink.cc @@ -208,6 +208,20 @@ TEST(RmdirTest, CanRemoveWithTrailingSlashes) { ASSERT_THAT(rmdir(slashslash.c_str()), SyscallSucceeds()); } +TEST(UnlinkTest, UnlinkAtEmptyPath) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + EXPECT_THAT(unlinkat(fd.get(), "", 0), SyscallFailsWithErrno(ENOENT)); + + auto dirInDir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); + auto dirFD = ASSERT_NO_ERRNO_AND_VALUE( + Open(dirInDir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(unlinkat(dirFD.get(), "", AT_REMOVEDIR), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc index ce1899f45..2a8699a7b 100644 --- a/test/syscalls/linux/vdso_clock_gettime.cc +++ b/test/syscalls/linux/vdso_clock_gettime.cc @@ -38,8 +38,6 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { switch (info.param) { case CLOCK_MONOTONIC: return "CLOCK_MONOTONIC"; - case CLOCK_REALTIME: - return "CLOCK_REALTIME"; case CLOCK_BOOTTIME: return "CLOCK_BOOTTIME"; default: @@ -47,59 +45,36 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { } } -class CorrectVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; +class MonotonicVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; -TEST_P(CorrectVDSOClockTest, IsCorrect) { +TEST_P(MonotonicVDSOClockTest, IsCorrect) { + // The VDSO implementation of clock_gettime() uses the TSC. On KVM, sentry and + // application TSCs can be very desynchronized; see + // sentry/platform/kvm/kvm.vCPU.setSystemTime(). + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + // Check that when we alternate readings from the clock_gettime syscall and + // the VDSO's implementation, we observe the combined sequence as being + // monotonic. struct timespec tvdso, tsys; absl::Time vdso_time, sys_time; - uint64_t total_calls = 0; - - // It is expected that 82.5% of clock_gettime calls will be less than 100us - // skewed from the system time. - // Unfortunately this is not only influenced by the VDSO clock skew, but also - // by arbitrary scheduling delays and the like. The test is therefore - // regularly disabled. - std::map<absl::Duration, std::tuple<double, uint64_t, uint64_t>> confidence = - { - {absl::Microseconds(100), std::make_tuple(0.825, 0, 0)}, - {absl::Microseconds(250), std::make_tuple(0.94, 0, 0)}, - {absl::Milliseconds(1), std::make_tuple(0.999, 0, 0)}, - }; - - absl::Time start = absl::Now(); - while (absl::Now() < start + absl::Seconds(30)) { - EXPECT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); - EXPECT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), - SyscallSucceeds()); - + ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), + SyscallSucceeds()); + sys_time = absl::TimeFromTimespec(tsys); + auto end = absl::Now() + absl::Seconds(10); + while (absl::Now() < end) { + ASSERT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); vdso_time = absl::TimeFromTimespec(tvdso); - - for (auto const& conf : confidence) { - std::get<1>(confidence[conf.first]) += - (sys_time - vdso_time) < conf.first; - } - + EXPECT_LE(sys_time, vdso_time); + ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), + SyscallSucceeds()); sys_time = absl::TimeFromTimespec(tsys); - - for (auto const& conf : confidence) { - std::get<2>(confidence[conf.first]) += - (vdso_time - sys_time) < conf.first; - } - - ++total_calls; - } - - for (auto const& conf : confidence) { - EXPECT_GE(std::get<1>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); - EXPECT_GE(std::get<2>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); + EXPECT_LE(vdso_time, sys_time); } } -INSTANTIATE_TEST_SUITE_P(ClockGettime, CorrectVDSOClockTest, - ::testing::Values(CLOCK_MONOTONIC, CLOCK_REALTIME, - CLOCK_BOOTTIME), +INSTANTIATE_TEST_SUITE_P(ClockGettime, MonotonicVDSOClockTest, + ::testing::Values(CLOCK_MONOTONIC, CLOCK_BOOTTIME), PrintClockId); } // namespace diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 39b5b2f56..77bcfbb8a 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -133,6 +133,91 @@ TEST_F(WriteTest, WriteExceedsRLimit) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(WriteTest, WriteIncrementOffset) { + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 0), SyscallSucceedsWithValue(0)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); + + const int bytes_total = 1024; + + EXPECT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); +} + +TEST_F(WriteTest, WriteIncrementOffsetSeek) { + const std::string data = "hello world\n"; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + const int seek_offset = data.size() / 2; + ASSERT_THAT(lseek(fd, seek_offset, SEEK_SET), + SyscallSucceedsWithValue(seek_offset)); + + const int write_bytes = 512; + EXPECT_THAT(WriteBytes(fd, write_bytes), + SyscallSucceedsWithValue(write_bytes)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(seek_offset + write_bytes)); +} + +TEST_F(WriteTest, WriteIncrementOffsetAppend) { + const std::string data = "hello world\n"; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpfile.path().c_str(), O_WRONLY | O_APPEND)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(data.size() + 1024)); +} + +TEST_F(WriteTest, WriteIncrementOffsetEOF) { + const std::string data = "hello world\n"; + const TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + EXPECT_THAT(lseek(fd, 0, SEEK_END), SyscallSucceedsWithValue(data.size())); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_END), + SyscallSucceedsWithValue(data.size() + 1024)); +} + +TEST_F(WriteTest, PwriteNoChangeOffset) { + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + const std::string data = "hello world\n"; + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); + + const int bytes_total = 1024; + ASSERT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); + ASSERT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), bytes_total), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index cbcf08451..bd3f829c4 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -28,6 +28,7 @@ #include "test/syscalls/linux/file_base.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -37,6 +38,8 @@ namespace testing { namespace { +using ::gvisor::testing::IsTmpfs; + class XattrTest : public FileTest {}; TEST_F(XattrTest, XattrNonexistentFile) { @@ -229,7 +232,7 @@ TEST_F(XattrTest, XattrOnInvalidFileTypes) { EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); } -TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { +TEST_F(XattrTest, SetXattrSizeSmallerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -244,7 +247,7 @@ TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrZeroSize) { +TEST_F(XattrTest, SetXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -256,7 +259,7 @@ TEST_F(XattrTest, SetxattrZeroSize) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, SetxattrSizeTooLarge) { +TEST_F(XattrTest, SetXattrSizeTooLarge) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; @@ -271,7 +274,7 @@ TEST_F(XattrTest, SetxattrSizeTooLarge) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { +TEST_F(XattrTest, SetXattrNullValueAndNonzeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0), @@ -280,7 +283,7 @@ TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { +TEST_F(XattrTest, SetXattrNullValueAndZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -288,7 +291,7 @@ TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { +TEST_F(XattrTest, SetXattrValueTooLargeButOKSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val(XATTR_SIZE_MAX + 1); @@ -304,7 +307,7 @@ TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrReplaceWithSmaller) { +TEST_F(XattrTest, SetXattrReplaceWithSmaller) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -319,7 +322,7 @@ TEST_F(XattrTest, SetxattrReplaceWithSmaller) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrReplaceWithLarger) { +TEST_F(XattrTest, SetXattrReplaceWithLarger) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -333,7 +336,7 @@ TEST_F(XattrTest, SetxattrReplaceWithLarger) { EXPECT_EQ(buf, val); } -TEST_F(XattrTest, SetxattrCreateFlag) { +TEST_F(XattrTest, SetXattrCreateFlag) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), @@ -344,7 +347,7 @@ TEST_F(XattrTest, SetxattrCreateFlag) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrReplaceFlag) { +TEST_F(XattrTest, SetXattrReplaceFlag) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), @@ -356,14 +359,14 @@ TEST_F(XattrTest, SetxattrReplaceFlag) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrInvalidFlags) { +TEST_F(XattrTest, SetXattrInvalidFlags) { const char* path = test_file_name_.c_str(); int invalid_flags = 0xff; EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags), SyscallFailsWithErrno(EINVAL)); } -TEST_F(XattrTest, Getxattr) { +TEST_F(XattrTest, GetXattr) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; int val = 1234; @@ -375,7 +378,7 @@ TEST_F(XattrTest, Getxattr) { EXPECT_EQ(buf, val); } -TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { +TEST_F(XattrTest, GetXattrSizeSmallerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -387,7 +390,7 @@ TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, GetxattrSizeLargerThanValue) { +TEST_F(XattrTest, GetXattrSizeLargerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -402,7 +405,7 @@ TEST_F(XattrTest, GetxattrSizeLargerThanValue) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, GetxattrZeroSize) { +TEST_F(XattrTest, GetXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -415,7 +418,7 @@ TEST_F(XattrTest, GetxattrZeroSize) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, GetxattrSizeTooLarge) { +TEST_F(XattrTest, GetXattrSizeTooLarge) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -431,7 +434,7 @@ TEST_F(XattrTest, GetxattrSizeTooLarge) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, GetxattrNullValue) { +TEST_F(XattrTest, GetXattrNullValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -442,7 +445,7 @@ TEST_F(XattrTest, GetxattrNullValue) { SyscallFailsWithErrno(EFAULT)); } -TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { +TEST_F(XattrTest, GetXattrNullValueAndZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -458,13 +461,13 @@ TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size)); } -TEST_F(XattrTest, GetxattrNonexistentName) { +TEST_F(XattrTest, GetXattrNonexistentName) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, Listxattr) { +TEST_F(XattrTest, ListXattr) { const char* path = test_file_name_.c_str(); const std::string name = "user.test"; const std::string name2 = "user.test2"; @@ -490,7 +493,7 @@ TEST_F(XattrTest, Listxattr) { EXPECT_EQ(got, expected); } -TEST_F(XattrTest, ListxattrNoXattrs) { +TEST_F(XattrTest, ListXattrNoXattrs) { const char* path = test_file_name_.c_str(); std::vector<char> list, expected; @@ -498,13 +501,13 @@ TEST_F(XattrTest, ListxattrNoXattrs) { SyscallSucceedsWithValue(0)); EXPECT_EQ(list, expected); - // Listxattr should succeed if there are no attributes, even if the buffer + // ListXattr should succeed if there are no attributes, even if the buffer // passed in is a nullptr. EXPECT_THAT(listxattr(path, nullptr, sizeof(list)), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, ListxattrNullBuffer) { +TEST_F(XattrTest, ListXattrNullBuffer) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -513,7 +516,7 @@ TEST_F(XattrTest, ListxattrNullBuffer) { SyscallFailsWithErrno(EFAULT)); } -TEST_F(XattrTest, ListxattrSizeTooSmall) { +TEST_F(XattrTest, ListXattrSizeTooSmall) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -523,7 +526,7 @@ TEST_F(XattrTest, ListxattrSizeTooSmall) { SyscallFailsWithErrno(ERANGE)); } -TEST_F(XattrTest, ListxattrZeroSize) { +TEST_F(XattrTest, ListXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -604,6 +607,83 @@ TEST_F(XattrTest, XattrWithFD) { EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); } +TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { + // Trusted namespace not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.test"; + + // Writing to the trusted.* xattr namespace requires CAP_SYS_ADMIN in the root + // user namespace. There's no easy way to check that, other than trying the + // operation and seeing what happens. We'll call removexattr because it's + // simplest. + if (removexattr(path, name) < 0) { + SKIP_IF(errno == EPERM); + FAIL() << "unexpected errno from removexattr: " << errno; + } + + // Set. + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + // Get. + char got = '\0'; + EXPECT_THAT(getxattr(path, name, &got, size), SyscallSucceedsWithValue(size)); + EXPECT_EQ(val, got); + + // List. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + // Remove. + EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); + + // Get should now return ENODATA. + EXPECT_THAT(getxattr(path, name, &got, size), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, TrustedNamespaceWithoutCapSysAdmin) { + // Trusted namespace not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); + + // Drop CAP_SYS_ADMIN if we have it. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { + EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); + } + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.test"; + + // Set fails. + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + + // Get fails. + char got = '\0'; + EXPECT_THAT(getxattr(path, name, &got, size), SyscallFailsWithErrno(ENODATA)); + + // List still works, but returns no items. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), SyscallSucceedsWithValue(0)); + + // Remove fails. + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); +} + } // namespace } // namespace testing diff --git a/test/util/BUILD b/test/util/BUILD index 2a17c33ee..26c2b6a2f 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -46,6 +46,13 @@ cc_library( ) cc_library( + name = "fuse_util", + testonly = 1, + srcs = ["fuse_util.cc"], + hdrs = ["fuse_util.h"], +) + +cc_library( name = "proc_util", testonly = 1, srcs = ["proc_util.cc"], @@ -247,7 +254,7 @@ cc_library( ], hdrs = ["test_util.h"], defines = select_system(), - deps = [ + deps = coreutil() + [ ":fs_util", ":logging", ":posix_error", diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 052781445..b16055dd8 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -15,7 +15,11 @@ #include "test/util/fs_util.h" #include <dirent.h> +#ifdef __linux__ +#include <linux/magic.h> +#endif // __linux__ #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -125,12 +129,12 @@ PosixErrorOr<struct stat> Fstat(int fd) { PosixErrorOr<bool> Exists(absl::string_view path) { struct stat stat_buf; - int res = stat(std::string(path).c_str(), &stat_buf); + int res = lstat(std::string(path).c_str(), &stat_buf); if (res < 0) { if (errno == ENOENT) { return false; } - return PosixError(errno, absl::StrCat("stat ", path)); + return PosixError(errno, absl::StrCat("lstat ", path)); } return true; } @@ -629,5 +633,35 @@ PosixErrorOr<std::string> ProcessExePath(int pid) { return ReadLink(absl::StrCat("/proc/", pid, "/exe")); } +#ifdef __linux__ +PosixErrorOr<bool> IsTmpfs(const std::string& path) { + struct statfs stat; + if (statfs(path.c_str(), &stat)) { + if (errno == ENOENT) { + // Nothing at path, don't raise this as an error. Instead, just report no + // tmpfs at path. + return false; + } + return PosixError(errno, + absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); + } + return stat.f_type == TMPFS_MAGIC; +} +#endif // __linux__ + +PosixErrorOr<bool> IsOverlayfs(const std::string& path) { + struct statfs stat; + if (statfs(path.c_str(), &stat)) { + if (errno == ENOENT) { + // Nothing at path, don't raise this as an error. Instead, just report no + // overlayfs at path. + return false; + } + return PosixError(errno, + absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); + } + return stat.f_type == OVERLAYFS_SUPER_MAGIC; +} + } // namespace testing } // namespace gvisor diff --git a/test/util/fs_util.h b/test/util/fs_util.h index caf19b24d..c99cf5eb7 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -17,6 +17,7 @@ #include <dirent.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -37,6 +38,10 @@ constexpr int kOLargeFile = 00400000; #error "Unknown architecture" #endif +// From linux/magic.h. For some reason, not defined in the headers for some +// build environments. +#define OVERLAYFS_SUPER_MAGIC 0x794c7630 + // Returns a status or the current working directory. PosixErrorOr<std::string> GetCWD(); @@ -44,9 +49,14 @@ PosixErrorOr<std::string> GetCWD(); // can't be determined. PosixErrorOr<bool> Exists(absl::string_view path); -// Returns a stat structure for the given path or an error. +// Returns a stat structure for the given path or an error. If the path +// represents a symlink, it will be traversed. PosixErrorOr<struct stat> Stat(absl::string_view path); +// Returns a stat structure for the given path or an error. If the path +// represents a symlink, it will not be traversed. +PosixErrorOr<struct stat> Lstat(absl::string_view path); + // Returns a stat struct for the given fd. PosixErrorOr<struct stat> Fstat(int fd); @@ -173,6 +183,14 @@ std::string CleanPath(absl::string_view path); // Returns the full path to the executable of the given pid or a PosixError. PosixErrorOr<std::string> ProcessExePath(int pid); +#ifdef __linux__ +// IsTmpfs returns true if the file at path is backed by tmpfs. +PosixErrorOr<bool> IsTmpfs(const std::string& path); +#endif // __linux__ + +// IsOverlayfs returns true if the file at path is backed by overlayfs. +PosixErrorOr<bool> IsOverlayfs(const std::string& path); + namespace internal { // Not part of the public API. std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); diff --git a/test/util/fuse_util.cc b/test/util/fuse_util.cc new file mode 100644 index 000000000..027f8386c --- /dev/null +++ b/test/util/fuse_util.cc @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/fuse_util.h" + +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> + +namespace gvisor { +namespace testing { + +// Create a default FuseAttr struct with specified mode, inode, and size. +fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size) { + const int time_sec = 1595436289; + const int time_nsec = 134150844; + return (struct fuse_attr){ + .ino = inode, + .size = size, + .blocks = 4, + .atime = time_sec, + .mtime = time_sec, + .ctime = time_sec, + .atimensec = time_nsec, + .mtimensec = time_nsec, + .ctimensec = time_nsec, + .mode = mode, + .nlink = 2, + .uid = 1234, + .gid = 4321, + .rdev = 12, + .blksize = 4096, + }; +} + +// Create response body with specified mode, nodeID, and size. +fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, uint64_t size) { + struct fuse_entry_out default_entry_out = { + .nodeid = node_id, + .generation = 0, + .entry_valid = 0, + .attr_valid = 0, + .entry_valid_nsec = 0, + .attr_valid_nsec = 0, + .attr = DefaultFuseAttr(mode, node_id, size), + }; + return default_entry_out; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/fuse_util.h b/test/util/fuse_util.h new file mode 100644 index 000000000..544fe1b38 --- /dev/null +++ b/test/util/fuse_util.h @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_FUSE_UTIL_H_ +#define GVISOR_TEST_UTIL_FUSE_UTIL_H_ + +#include <linux/fuse.h> +#include <sys/uio.h> + +#include <string> +#include <vector> + +namespace gvisor { +namespace testing { + +// The fundamental generation function with a single argument. If passed by +// std::string or std::vector<char>, it will call specialized versions as +// implemented below. +template <typename T> +std::vector<struct iovec> FuseGenerateIovecs(T &first) { + return {(struct iovec){.iov_base = &first, .iov_len = sizeof(first)}}; +} + +// If an argument is of type std::string, it must be used in read-only scenario. +// Because we are setting up iovec, which contains the original address of a +// data structure, we have to drop const qualification. Usually used with +// variable-length payload data. +template <typename T = std::string> +std::vector<struct iovec> FuseGenerateIovecs(std::string &first) { + // Pad one byte for null-terminate c-string. + return {(struct iovec){.iov_base = const_cast<char *>(first.c_str()), + .iov_len = first.size() + 1}}; +} + +// If an argument is of type std::vector<char>, it must be used in write-only +// scenario and the size of the variable must be greater than or equal to the +// size of the expected data. Usually used with variable-length payload data. +template <typename T = std::vector<char>> +std::vector<struct iovec> FuseGenerateIovecs(std::vector<char> &first) { + return {(struct iovec){.iov_base = first.data(), .iov_len = first.size()}}; +} + +// A helper function to set up an array of iovec struct for testing purpose. +// Use variadic class template to generalize different numbers and different +// types of FUSE structs. +template <typename T, typename... Types> +std::vector<struct iovec> FuseGenerateIovecs(T &first, Types &...args) { + auto first_iovec = FuseGenerateIovecs(first); + auto iovecs = FuseGenerateIovecs(args...); + first_iovec.insert(std::end(first_iovec), std::begin(iovecs), + std::end(iovecs)); + return first_iovec; +} + +// Create a fuse_attr filled with the specified mode and inode. +fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size = 512); + +// Return a fuse_entry_out FUSE server response body. +fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, + uint64_t size = 512); + +} // namespace testing +} // namespace gvisor +#endif // GVISOR_TEST_UTIL_FUSE_UTIL_H_ diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c01f916aa..2cf0bea74 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -23,15 +23,15 @@ namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - PosixErrorOr<int> n = SlaveID(master); +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master) { + PosixErrorOr<int> n = ReplicaID(master); if (!n.ok()) { return PosixErrorOr<FileDescriptor>(n.error()); } return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); } -PosixErrorOr<int> SlaveID(const FileDescriptor& master) { +PosixErrorOr<int> ReplicaID(const FileDescriptor& master) { // Get pty index. int n; int ret = ioctl(master.get(), TIOCGPTN, &n); diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 0722da379..ed7658868 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -21,11 +21,11 @@ namespace gvisor { namespace testing { -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Opens the replica end of the passed master as R/W and nonblocking. +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master); -// Get the number of the slave end of the master. -PosixErrorOr<int> SlaveID(const FileDescriptor& master); +// Get the number of the replica end of the master. +PosixErrorOr<int> ReplicaID(const FileDescriptor& master); } // namespace testing } // namespace gvisor diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc index 9c10b6674..e1bdee7fd 100644 --- a/test/util/temp_path.cc +++ b/test/util/temp_path.cc @@ -56,7 +56,7 @@ void TryDeleteRecursively(std::string const& path) { if (undeleted_dirs || undeleted_files || !status.ok()) { std::cerr << path << ": failed to delete " << undeleted_dirs << " directories and " << undeleted_files - << " files: " << status; + << " files: " << status << std::endl; } } } diff --git a/test/util/test_util.cc b/test/util/test_util.cc index 8a037f45f..d0c1d6426 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -42,6 +42,7 @@ namespace testing { constexpr char kGvisorNetwork[] = "GVISOR_NETWORK"; constexpr char kGvisorVfs[] = "GVISOR_VFS"; +constexpr char kFuseEnabled[] = "FUSE_ENABLED"; bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; } @@ -68,6 +69,11 @@ bool IsRunningWithVFS1() { return strcmp(env, "VFS1") == 0; } +bool IsFUSEEnabled() { + const char* env = getenv(kFuseEnabled); + return env && strcmp(env, "TRUE") == 0; +} + // Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations // %ebx contains the address of the global offset table. %rbx is occasionally // used to address stack variables in presence of dynamic allocas. diff --git a/test/util/test_util.h b/test/util/test_util.h index 109078fc7..373c54f32 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -225,6 +225,7 @@ const std::string GvisorPlatform(); bool IsRunningWithHostinet(); // TODO(gvisor.dev/issue/1624): Delete once VFS1 is gone. bool IsRunningWithVFS1(); +bool IsFUSEEnabled(); #ifdef __linux__ void SetupGvisorDeathTest(); @@ -567,6 +568,25 @@ ssize_t ApplyFileIoSyscall(F const& f, size_t const count) { } // namespace internal +inline PosixErrorOr<std::string> ReadAllFd(int fd) { + std::string all; + all.reserve(128 * 1024); // arbitrary. + + std::vector<char> buffer(16 * 1024); + for (;;) { + auto const bytes = RetryEINTR(read)(fd, buffer.data(), buffer.size()); + if (bytes < 0) { + return PosixError(errno, "file read"); + } + if (bytes == 0) { + return std::move(all); + } + if (bytes > 0) { + all.append(buffer.data(), bytes); + } + } +} + inline ssize_t ReadFd(int fd, void* buf, size_t count) { return internal::ApplyFileIoSyscall( [&](size_t completed) { diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc index 694d21692..7210094eb 100644 --- a/test/util/test_util_runfiles.cc +++ b/test/util/test_util_runfiles.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef __fuchsia__ - #include <iostream> #include <string> @@ -46,5 +44,3 @@ std::string RunfilePath(std::string path) { } // namespace testing } // namespace gvisor - -#endif // __fuchsia__ diff --git a/tools/BUILD b/tools/BUILD index 34b950644..da83877b1 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -1 +1,9 @@ +load("//tools:defs.bzl", "bzl_library") + package(licenses = ["notice"]) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/bazel.mk b/tools/bazel.mk index 9f4a40669..5e129b2ed 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -15,29 +15,54 @@ # limitations under the License. # See base Makefile. +SHELL=/bin/bash -o pipefail BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ xargs -n 1 basename 2>/dev/null) +BUILD_ROOT := $(CURDIR)/bazel-bin/ # Bazel container configuration (see below). USER ?= gvisor HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) +BUILDER_BASE := gvisor.dev/images/default +BUILDER_IMAGE := gvisor.dev/images/builder +BUILDER_NAME ?= gvisor-builder-$(HASH) DOCKER_NAME ?= gvisor-bazel-$(HASH) DOCKER_PRIVILEGED ?= --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock +DOCKER_CONFIG := /etc/docker/daemon.json + +# Bazel flags. +BAZEL := bazel $(STARTUP_OPTIONS) +OPTIONS += --color=no --curses=no -# Non-configurable. +# Basic options. UID := $(shell id -u ${USER}) GID := $(shell id -g ${USER}) USERADD_OPTIONS := FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS) +FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID) +FULL_DOCKER_RUN_OPTIONS += --entrypoint "" +FULL_DOCKER_RUN_OPTIONS += --init FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" +FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID) +FULL_DOCKER_EXEC_OPTIONS += --interactive +ifeq (true,$(shell [[ -t 0 ]] && echo true)) +FULL_DOCKER_EXEC_OPTIONS += --tty +endif + +# Add docker passthrough options. ifneq ($(DOCKER_PRIVILEGED),) FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" +# TODO(gvisor.dev/issue/1624): Remove docker config volume. This is required +# temporarily for checking VFS1 vs VFS2 by some tests. +FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)" +FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED) +FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED) DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET)) ifneq ($(GID),$(DOCKER_GROUP)) USERADD_OPTIONS += --groups $(DOCKER_GROUP) @@ -45,7 +70,40 @@ GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) && FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP) endif endif -SHELL=/bin/bash -o pipefail + +# Add KVM passthrough options. +ifneq (,$(wildcard /dev/kvm)) +FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm +KVM_GROUP := $(shell stat -c '%g' /dev/kvm) +ifneq ($(GID),$(KVM_GROUP)) +USERADD_OPTIONS += --groups $(KVM_GROUP) +GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) && +FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP) +endif +endif + +# Load the appropriate config. +ifneq (,$(BAZEL_CONFIG)) +OPTIONS += --config=$(BAZEL_CONFIG) +endif + +# NOTE: we pass -l to useradd below because otherwise you can hit a bug +# best described here: +# https://github.com/moby/moby/issues/5419#issuecomment-193876183 +# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out +# out of disk space. +bazel-image: load-default + @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi + docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \ + $(BUILDER_BASE) \ + sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ + $(GROUPADD_DOCKER) \ + useradd -l --uid $(UID) --non-unique --no-create-home \ + --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ + if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi" + docker commit $(BUILDER_NAME) $(BUILDER_IMAGE) + @docker rm -f $(BUILDER_NAME) +.PHONY: bazel-image ## ## Bazel helpers. @@ -60,45 +118,43 @@ SHELL=/bin/bash -o pipefail ## GCLOUD_CONFIG - The gcloud config directory (detect: detected). ## DOCKER_SOCKET - The Docker socket (default: detected). ## -bazel-server-start: load-default ## Starts the bazel server. +bazel-server-start: bazel-image ## Starts the bazel server. @mkdir -p $(BAZEL_CACHE) @mkdir -p $(GCLOUD_CONFIG) - docker run -d --rm \ - --init \ - --name $(DOCKER_NAME) \ - --user 0:0 $(DOCKER_GROUP_OPTIONS) \ + @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi + # This command runs a bazel server, and the container sticks around + # until the bazel server exits. This should ensure that it does not + # exit in the middle of running a build, but also it won't stick around + # forever. The build commands wrap around an appropriate exec into the + # container in order to perform work via the bazel client. + docker run -d --rm --name $(DOCKER_NAME) \ -v "$(CURDIR):$(CURDIR)" \ --workdir "$(CURDIR)" \ - --entrypoint "" \ $(FULL_DOCKER_RUN_OPTIONS) \ - gvisor.dev/images/default \ - sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ - $(GROUPADD_DOCKER) \ - useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ - bazel version && \ - exec tail --pid=\$$(bazel info server_pid) -f /dev/null" - @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; \ - if ! docker ps | grep $(DOCKER_NAME); then exit 1; else sleep 1; fi; done + $(BUILDER_IMAGE) \ + sh -c "tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null" .PHONY: bazel-server-start bazel-shutdown: ## Shuts down a running bazel server. - @docker exec --user $(UID):$(GID) $(DOCKER_NAME) bazel shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \ + rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] .PHONY: bazel-shutdown bazel-alias: ## Emits an alias that can be used within the shell. - @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'" + @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'" .PHONY: bazel-alias bazel-server: ## Ensures that the server exists. Used as an internal target. - @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true || $(MAKE) bazel-server-start .PHONY: bazel-server -build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c 'bazel $(STARTUP_OPTIONS) build $(OPTIONS) $(TARGETS)' +build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) "$(TARGETS)"' build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ | grep -E "^ bazel-bin/" \ - | awk "{print $$1;}" \ + | tr -d '\r' \ + | awk '{$$1=$$1};1' \ | xargs -n 1 -I {} sh -c "$(1)" build: bazel-server @@ -109,7 +165,7 @@ copy: bazel-server ifeq (,$(DESTINATION)) $(error Destination not provided.) endif - @$(call build_paths,cp -a {} $(DESTINATION)) + @$(call build_paths,cp -fa {} $(DESTINATION)) run: bazel-server @$(call build_paths,{} $(ARGS)) @@ -119,6 +175,12 @@ sudo: bazel-server @$(call build_paths,sudo -E {} $(ARGS)) .PHONY: sudo +test: OPTIONS += --test_output=errors --keep_going --verbose_failures=true test: bazel-server - @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel $(STARTUP_OPTIONS) test $(OPTIONS) $(TARGETS) + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS) .PHONY: test + +query: + @$(MAKE) bazel-server >&2 # If we need to start, ensure stdout is not polluted. + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) query $(OPTIONS) "$(TARGETS)" 2>/dev/null' +.PHONY: query diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD index f2f80bae1..8d4356119 100644 --- a/tools/bazeldefs/BUILD +++ b/tools/bazeldefs/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "rbe_platform", "rbe_toolchain") +load("//tools:defs.bzl", "bzl_library", "rbe_platform", "rbe_toolchain") package(licenses = ["notice"]) @@ -49,3 +49,58 @@ rbe_toolchain( toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", ) + +# Updated versions of the above, compatible with bazel3. +rbe_platform( + name = "rbe_ubuntu1604_bazel3", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains_bazel3//constraints:xenial", + "@bazel_toolchains_bazel3//constraints/sanitizers:support_msan", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272" + } + properties: { + name: "dockerAddCapabilities" + value: "SYS_ADMIN" + } + properties: { + name: "dockerPrivileged" + value: "true" + } + """, +) + +rbe_toolchain( + name = "cc-toolchain-clang-x86_64-default_bazel3", + exec_compatible_with = [], + tags = [ + "manual", + ], + target_compatible_with = [], + toolchain = "@bazel_toolchains_bazel3//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +bzl_library( + name = "platforms_bzl", + srcs = ["platforms.bzl"], + visibility = ["//visibility:private"], +) + +bzl_library( + name = "tags_bzl", + srcs = ["tags.bzl"], + visibility = ["//visibility:private"], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 620c460de..cf5b1dc0d 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -2,15 +2,16 @@ load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle") load("@bazel_skylib//rules:build_test.bzl", _build_test = "build_test") +load("@bazel_skylib//:bzl_library.bzl", _bzl_library = "bzl_library") load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test") load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library") load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") -load("@pydeps//:requirements.bzl", _py_requirement = "requirement") load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library") build_test = _build_test +bzl_library = _bzl_library cc_library = _cc_library cc_flags_supplier = _cc_flags_supplier cc_proto_library = _cc_proto_library @@ -25,13 +26,14 @@ gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" pkg_deb = _pkg_deb pkg_tar = _pkg_tar -py_library = native.py_library py_binary = native.py_binary -py_test = native.py_test rbe_platform = native.platform rbe_toolchain = native.toolchain vdso_linker_option = "-fuse-ld=gold " +def short_path(path): + return path + def proto_library(name, has_services = None, **kwargs): native.proto_library( name = name, @@ -85,13 +87,14 @@ def cc_binary(name, static = False, **kwargs): **kwargs ) -def go_binary(name, static = False, pure = False, **kwargs): +def go_binary(name, static = False, pure = False, x_defs = None, **kwargs): """Build a go binary. Args: name: name of the target. static: build a static binary. pure: build without cgo. + x_defs: additional definitions. **kwargs: rest of the arguments are passed to _go_binary. """ if static: @@ -100,6 +103,7 @@ def go_binary(name, static = False, pure = False, **kwargs): kwargs["pure"] = "on" _go_binary( name = name, + x_defs = x_defs, **kwargs ) @@ -143,26 +147,33 @@ def go_rule(rule, implementation, **kwargs): Returns: The result of invoking the rule. """ - attrs = kwargs.pop("attrs", []) + attrs = kwargs.pop("attrs", dict()) attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data") attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib") toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) -def go_context(ctx): +def go_test_library(target): + if hasattr(target.attr, "embed") and len(target.attr.embed) > 0: + return target.attr.embed[0] + return None + +def go_context(ctx, std = False): + # We don't change anything for the standard library analysis. All Go files + # are available in all instances. Note that this includes the standard + # library sources, which are analyzed by nogo. go_ctx = _go_context(ctx) return struct( go = go_ctx.go, env = go_ctx.env, - runfiles = depset([go_ctx.go] + go_ctx.sdk.tools + go_ctx.stdlib.libs), + nogo_args = [], + stdlib_srcs = go_ctx.sdk.srcs, + runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), goos = go_ctx.sdk.goos, goarch = go_ctx.sdk.goarch, tags = go_ctx.tags, ) -def py_requirement(name, direct = True): - return _py_requirement(name) - def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): values = { "@bazel_tools//src/conditions:linux_x86_64": amd64, @@ -180,3 +191,6 @@ def default_installer(): def default_net_util(): return [] # Nothing needed. + +def coreutil(): + return [] # Nothing needed. diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD index b8c3ddf44..8956be621 100644 --- a/tools/checkescape/BUILD +++ b/tools/checkescape/BUILD @@ -8,7 +8,6 @@ go_library( nogo = False, visibility = ["//tools/nogo:__subpackages__"], deps = [ - "//tools/nogo/data", "@org_golang_x_tools//go/analysis:go_tool_library", "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library", "@org_golang_x_tools//go/ssa:go_tool_library", diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index f8def4823..f5bba9980 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -61,20 +61,20 @@ package checkescape import ( "bufio" "bytes" + "flag" "fmt" "go/ast" "go/token" "go/types" "io" "os" + "os/exec" "path/filepath" - "strconv" "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" "golang.org/x/tools/go/ssa" - "gvisor.dev/gvisor/tools/nogo/data" ) const ( @@ -91,81 +91,20 @@ const ( exempt = "// escapes" ) -// escapingBuiltins are builtins known to escape. -// -// These are lowered at an earlier stage of compilation to explicit function -// calls, but are not available for recursive analysis. -var escapingBuiltins = []string{ - "append", - "makemap", - "newobject", - "mallocgc", -} - -// Analyzer defines the entrypoint. -var Analyzer = &analysis.Analyzer{ - Name: "checkescape", - Doc: "surfaces recursive escape analysis results", - Run: run, - Requires: []*analysis.Analyzer{buildssa.Analyzer}, - FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)}, -} - -// packageEscapeFacts is the set of all functions in a package, and whether or -// not they recursively pass escape analysis. -// -// All the type names for receivers are encoded in the full key. The key -// represents the fully qualified package and type name used at link time. -type packageEscapeFacts struct { - Funcs map[string][]Escape -} - -// AFact implements analysis.Fact.AFact. -func (*packageEscapeFacts) AFact() {} - -// CallSite is a single call site. -// -// These can be chained. -type CallSite struct { - LocalPos token.Pos - Resolved LinePosition -} - -// Escape is a single escape instance. -type Escape struct { - Reason EscapeReason - Detail string - Chain []CallSite -} - -// LinePosition is a low-resolution token.Position. -// -// This is used to match against possible exemptions placed in the source. -type LinePosition struct { - Filename string - Line int -} +var ( + // Binary is the binary under analysis. + // + // See Reader, below. + binary = flag.String("binary", "", "binary under analysis") -// String implements fmt.Stringer.String. -func (e *LinePosition) String() string { - return fmt.Sprintf("%s:%d", e.Filename, e.Line) -} + // Reader is the input stream. + // + // This may be set instead of Binary. + Reader io.Reader -// String implements fmt.Stringer.String. -// -// Note that this string will contain new lines. -func (e *Escape) String() string { - var b bytes.Buffer - fmt.Fprintf(&b, "%s", e.Reason.String()) - for i, cs := range e.Chain { - if i == len(e.Chain)-1 { - fmt.Fprintf(&b, "\n @ %s → %s", cs.Resolved.String(), e.Detail) - } else { - fmt.Fprintf(&b, "\n + %s", cs.Resolved.String()) - } - } - return b.String() -} + // Tool is the tool used to dump a binary. + tool = flag.String("dump_tool", "", "tool used to dump a binary") +) // EscapeReason is an escape reason. // @@ -173,12 +112,12 @@ func (e *Escape) String() string { type EscapeReason int const ( - interfaceInvoke EscapeReason = iota - unknownPackage - allocation + allocation EscapeReason = iota builtin + interfaceInvoke dynamicCall stackSplit + unknownPackage reasonCount // Count for below. ) @@ -189,17 +128,17 @@ const ( func (e EscapeReason) String() string { switch e { case interfaceInvoke: - return "interface: function invocation via interface" + return "interface: call to potentially allocating function" case unknownPackage: return "unknown: no package information available" case allocation: - return "heap: call to runtime heap allocation" + return "heap: explicit allocation" case builtin: - return "builtin: call to runtime builtin" + return "builtin: call to potentially allocating builtin" case dynamicCall: - return "dynamic: call via dynamic function" + return "dynamic: call to potentially allocating function" case stackSplit: - return "stack: stack split on function entry" + return "stack: possible split on function entry" default: panic(fmt.Sprintf("unknown reason: %d", e)) } @@ -228,52 +167,289 @@ var escapeTypes = func() map[string]EscapeReason { return result }() -// EscapeCount counts escapes. +// escapingBuiltins are builtins known to escape. +// +// These are lowered at an earlier stage of compilation to explicit function +// calls, but are not available for recursive analysis. +var escapingBuiltins = []string{ + "append", + "makemap", + "newobject", + "mallocgc", +} + +// packageEscapeFacts is the set of all functions in a package, and whether or +// not they recursively pass escape analysis. +// +// All the type names for receivers are encoded in the full key. The key +// represents the fully qualified package and type name used at link time. // -// It is used to avoid accumulating too many escapes for the same reason, for -// the same function. We limit each class to 3 instances (arbitrarily). -type EscapeCount struct { - byReason [reasonCount]uint32 +// Note that each Escapes object is a summary. Local findings may be reported +// using more detailed information. +type packageEscapeFacts struct { + Funcs map[string]Escapes +} + +// AFact implements analysis.Fact.AFact. +func (*packageEscapeFacts) AFact() {} + +// Analyzer includes specific results. +var Analyzer = &analysis.Analyzer{ + Name: "checkescape", + Doc: "escape analysis checks based on +checkescape annotations", + Run: runSelectEscapes, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, + FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)}, +} + +// EscapeAnalyzer includes all local escape results. +var EscapeAnalyzer = &analysis.Analyzer{ + Name: "checkescape", + Doc: "complete local escape analysis results (requires Analyzer facts)", + Run: runAllEscapes, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, } -// maxRecordsPerReason is the number of explicit records. +// LinePosition is a low-resolution token.Position. // -// See EscapeCount (and usage), and Record implementation. -const maxRecordsPerReason = 5 - -// Record records the reason or returns false if it should not be added. -func (ec *EscapeCount) Record(reason EscapeReason) bool { - ec.byReason[reason]++ - if ec.byReason[reason] > maxRecordsPerReason { - return false +// This is used to match against possible exemptions placed in the source. +type LinePosition struct { + Filename string + Line int +} + +// String implements fmt.Stringer.String. +func (e LinePosition) String() string { + return fmt.Sprintf("%s:%d", e.Filename, e.Line) +} + +// Simplified returns the simplified name. +func (e LinePosition) Simplified() string { + return fmt.Sprintf("%s:%d", filepath.Base(e.Filename), e.Line) +} + +// CallSite is a single call site. +// +// These can be chained. +type CallSite struct { + LocalPos token.Pos + Resolved LinePosition +} + +// IsValid indicates whether the CallSite is valid or not. +func (cs *CallSite) IsValid() bool { + return cs.LocalPos.IsValid() +} + +// Escapes is a collection of escapes. +// +// We record at most one escape for each reason, but record the number of +// escapes that were omitted. +// +// This object should be used to summarize all escapes for a single line (local +// analysis) or a single function (package facts). +// +// All fields are exported for gob. +type Escapes struct { + CallSites [reasonCount][]CallSite + Details [reasonCount]string + Omitted [reasonCount]int +} + +// add is called by Add and Merge. +func (es *Escapes) add(r EscapeReason, detail string, omitted int, callSites ...CallSite) { + if es.CallSites[r] != nil { + // We will either be replacing the current escape or dropping + // the added one. Either way, we increment omitted by the + // appropriate amount. + es.Omitted[r]++ + // If the callSites in the other is only a single element, then + // we will universally favor this. This provides the cleanest + // set of escapes to summarize, and more importantly: if there + if len(es.CallSites) == 1 || len(callSites) != 1 { + return + } + } + es.Details[r] = detail + es.CallSites[r] = callSites + es.Omitted[r] += omitted +} + +// Add adds a single escape. +func (es *Escapes) Add(r EscapeReason, detail string, callSites ...CallSite) { + es.add(r, detail, 0, callSites...) +} + +// IsEmpty returns true iff this Escapes is empty. +func (es *Escapes) IsEmpty() bool { + for _, cs := range es.CallSites { + if cs != nil { + return false + } } return true } +// Filter filters out all escapes except those matches the given reasons. +// +// If local is set, then non-local escapes will also be filtered. +func (es *Escapes) Filter(reasons []EscapeReason, local bool) { +FilterReasons: + for r := EscapeReason(0); r < reasonCount; r++ { + for i := 0; i < len(reasons); i++ { + if r == reasons[i] { + continue FilterReasons + } + } + // Zap this reason. + es.CallSites[r] = nil + es.Details[r] = "" + es.Omitted[r] = 0 + } + if !local { + return + } + for r := EscapeReason(0); r < reasonCount; r++ { + // Is does meet our local requirement? + if len(es.CallSites[r]) > 1 { + es.CallSites[r] = nil + es.Details[r] = "" + es.Omitted[r] = 0 + } + } +} + +// MergeWithCall merges these escapes with another. +// +// If callSite is nil, no call is added. +func (es *Escapes) MergeWithCall(other Escapes, callSite CallSite) { + for r := EscapeReason(0); r < reasonCount; r++ { + if other.CallSites[r] != nil { + // Construct our new call chain. + newCallSites := other.CallSites[r] + if callSite.IsValid() { + newCallSites = append([]CallSite{callSite}, newCallSites...) + } + // Add (potentially replacing) the underlying escape. + es.add(r, other.Details[r], other.Omitted[r], newCallSites...) + } + } +} + +// Reportf will call Reportf for each class of escapes. +func (es *Escapes) Reportf(pass *analysis.Pass) { + var b bytes.Buffer // Reused for all escapes. + for r := EscapeReason(0); r < reasonCount; r++ { + if es.CallSites[r] == nil { + continue + } + b.Reset() + fmt.Fprintf(&b, "%s ", r.String()) + if es.Omitted[r] > 0 { + fmt.Fprintf(&b, "(%d omitted) ", es.Omitted[r]) + } + for _, cs := range es.CallSites[r][1:] { + fmt.Fprintf(&b, "→ %s ", cs.Resolved.String()) + } + fmt.Fprintf(&b, "→ %s", es.Details[r]) + pass.Reportf(es.CallSites[r][0].LocalPos, b.String()) + } +} + +// MergeAll merges a sequence of escapes. +func MergeAll(others []Escapes) (es Escapes) { + for _, other := range others { + es.MergeWithCall(other, CallSite{}) + } + return +} + // loadObjdump reads the objdump output. // // This records if there is a call any function for every source line. It is // used only to remove false positives for escape analysis. The call will be // elided if escape analysis is able to put the object on the heap exclusively. -func loadObjdump() (map[LinePosition]string, error) { - f, err := os.Open(data.Objdump) +// +// Note that the map uses <basename.go>:<line> because that is all that is +// provided in the objdump format. Since this is all local, it is sufficient. +func loadObjdump() (map[string][]string, error) { + var ( + args []string + stdin io.Reader + ) + if *binary != "" { + args = append(args, *binary) + } else if Reader != nil { + stdin = Reader + } else { + // We have no input stream or binary. + return nil, fmt.Errorf("no binary or reader provided") + } + + // Construct our command. + cmd := exec.Command(*tool, args...) + cmd.Stdin = stdin + cmd.Stderr = os.Stderr + out, err := cmd.StdoutPipe() if err != nil { return nil, err } - defer f.Close() + if err := cmd.Start(); err != nil { + return nil, err + } + + // Identify calls by address or name. Note that this is also + // constructed dynamically below, as we encounted the addresses. + // This is because some of the functions (duffzero) may have + // jump targets in the middle of the function itself. + funcsAllowed := map[string]struct{}{ + "runtime.duffzero": struct{}{}, + "runtime.duffcopy": struct{}{}, + "runtime.racefuncenter": struct{}{}, + "runtime.gcWriteBarrier": struct{}{}, + "runtime.retpolineAX": struct{}{}, + "runtime.retpolineBP": struct{}{}, + "runtime.retpolineBX": struct{}{}, + "runtime.retpolineCX": struct{}{}, + "runtime.retpolineDI": struct{}{}, + "runtime.retpolineDX": struct{}{}, + "runtime.retpolineR10": struct{}{}, + "runtime.retpolineR11": struct{}{}, + "runtime.retpolineR12": struct{}{}, + "runtime.retpolineR13": struct{}{}, + "runtime.retpolineR14": struct{}{}, + "runtime.retpolineR15": struct{}{}, + "runtime.retpolineR8": struct{}{}, + "runtime.retpolineR9": struct{}{}, + "runtime.retpolineSI": struct{}{}, + "runtime.stackcheck": struct{}{}, + "runtime.settls": struct{}{}, + } + addrsAllowed := make(map[string]struct{}) // Build the map. - m := make(map[LinePosition]string) - r := bufio.NewReader(f) - var ( - lastField string - lastPos LinePosition - ) + nextFunc := "" // For funcsAllowed. + m := make(map[string][]string) + r := bufio.NewReader(out) +NextLine: for { line, err := r.ReadString('\n') if err != nil && err != io.EOF { return nil, err } + fields := strings.Fields(line) + + // Is this an "allowed" function definition? + if len(fields) >= 2 && fields[0] == "TEXT" { + nextFunc = strings.TrimSuffix(fields[1], "(SB)") + if _, ok := funcsAllowed[nextFunc]; !ok { + nextFunc = "" // Don't record addresses. + } + } + if nextFunc != "" && len(fields) > 2 { + // Save the given address (in hex form, as it appears). + addrsAllowed[fields[1]] = struct{}{} + } // We recognize lines corresponding to actual code (not the // symbol name or other metadata) and annotate them if they @@ -283,53 +459,70 @@ func loadObjdump() (map[LinePosition]string, error) { // // Lines look like this (including the first space): // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX - if len(line) > 0 && line[0] == ' ' { - fields := strings.Fields(line) + if len(fields) >= 5 && line[0] == ' ' { if !strings.Contains(fields[3], "CALL") { continue } + site := fields[0] + target := strings.TrimSuffix(fields[4], "(SB)") - // Ignore strings containing duffzero, which is just - // used by stack allocations for types that are large - // enough to warrant Duff's device. - if strings.Contains(line, "runtime.duffzero") { + // Ignore strings containing allowed functions. + if _, ok := funcsAllowed[target]; ok { continue } - - // Ignore the racefuncenter call, which is used for - // race builds. This does not escape. - if strings.Contains(line, "runtime.racefuncenter") { + if _, ok := addrsAllowed[target]; ok { continue } - - // Calculate the filename and line. Note that per the - // example above, the filename is not a fully qualified - // base, just the basename (what we require). - if fields[0] != lastField { - parts := strings.SplitN(fields[0], ":", 2) - lineNum, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return nil, err - } - lastPos = LinePosition{ - Filename: parts[0], - Line: int(lineNum), + if len(fields) > 5 { + // This may be a future relocation. Some + // objdump versions describe this differently. + // If it contains any of the functions allowed + // above as a string, we let it go. + softTarget := strings.Join(fields[5:], " ") + for name := range funcsAllowed { + if strings.Contains(softTarget, name) { + continue NextLine + } } - lastField = fields[0] - } - if _, ok := m[lastPos]; ok { - continue // Already marked. } - // Save the actual call for the detail. - m[lastPos] = strings.Join(fields[3:], " ") + // Does this exist already? + existing, ok := m[site] + if !ok { + existing = make([]string, 0, 1) + } + for _, other := range existing { + if target == other { + continue NextLine + } + } + existing = append(existing, target) + m[site] = existing // Update. } if err == io.EOF { break } } - return m, nil + // Zap any accidental false positives. + final := make(map[string][]string) + for site, calls := range m { + filteredCalls := make([]string, 0, len(calls)) + for _, call := range calls { + if _, ok := addrsAllowed[call]; ok { + continue // Omit this call. + } + filteredCalls = append(filteredCalls, call) + } + final[site] = filteredCalls + } + + // Wait for the dump to finish. + if err := cmd.Wait(); err != nil { + return nil, err + } + + return final, nil } // poser is a type that implements Pos. @@ -337,65 +530,148 @@ type poser interface { Pos() token.Pos } +// runSelectEscapes runs with only select escapes. +func runSelectEscapes(pass *analysis.Pass) (interface{}, error) { + return run(pass, false) +} + +// runAllEscapes runs with all escapes included. +func runAllEscapes(pass *analysis.Pass) (interface{}, error) { + return run(pass, true) +} + +// findReasons extracts reasons from the function. +func findReasons(pass *analysis.Pass, fdecl *ast.FuncDecl) ([]EscapeReason, bool, map[EscapeReason]bool) { + // Is there a comment? + if fdecl.Doc == nil { + return nil, false, nil + } + var ( + reasons []EscapeReason + local bool + testReasons = make(map[EscapeReason]bool) // reason -> local? + ) + // Scan all lines. + found := false + for _, c := range fdecl.Doc.List { + // Does the comment contain a +checkescape line? + if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) { + continue + } + if c.Text == magic { + // Default: hard reasons, local only. + reasons = hardReasons + local = true + } else if strings.HasPrefix(c.Text, magicParams) { + // Extract specific reasons. + types := strings.Split(c.Text[len(magicParams):], ",") + found = true // For below. + for i := 0; i < len(types); i++ { + if types[i] == "local" { + // Limit search to local escapes. + local = true + } else if types[i] == "all" { + // Append all reasons. + reasons = append(reasons, allReasons...) + } else if types[i] == "hard" { + // Append all hard reasons. + reasons = append(reasons, hardReasons...) + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + reasons = append(reasons, r) + } + } + } else if strings.HasPrefix(c.Text, testMagic) { + types := strings.Split(c.Text[len(testMagic):], ",") + local := false + for i := 0; i < len(types); i++ { + if types[i] == "local" { + local = true + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + if v, ok := testReasons[r]; ok && v { + // Already registered as local. + continue + } + testReasons[r] = local + } + } + } + } + if len(reasons) == 0 && found { + // A magic annotation was provided, but no reasons. + pass.Reportf(fdecl.Pos(), "no reasons provided") + } + return reasons, local, testReasons +} + // run performs the analysis. -func run(pass *analysis.Pass) (interface{}, error) { +func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) { calls, err := loadObjdump() if err != nil { return nil, err } - pef := packageEscapeFacts{ - Funcs: make(map[string][]Escape), - } + allEscapes := make(map[string][]Escapes) + mergedEscapes := make(map[string]Escapes) linePosition := func(inst, parent poser) LinePosition { p := pass.Fset.Position(inst.Pos()) if (p.Filename == "" || p.Line == 0) && parent != nil { p = pass.Fset.Position(parent.Pos()) } return LinePosition{ - Filename: filepath.Base(p.Filename), + Filename: p.Filename, Line: p.Line, } } - hasCall := func(inst poser) (string, bool) { - p := linePosition(inst, nil) - s, ok := calls[p] - return s, ok - } callSite := func(inst ssa.Instruction) CallSite { return CallSite{ LocalPos: inst.Pos(), Resolved: linePosition(inst, inst.Parent()), } } - escapes := func(reason EscapeReason, detail string, inst ssa.Instruction, ec *EscapeCount) []Escape { - if !ec.Record(reason) { - return nil // Skip. - } - es := Escape{ - Reason: reason, - Detail: detail, - Chain: []CallSite{callSite(inst)}, + hasCall := func(inst poser) (string, bool) { + p := linePosition(inst, nil) + s, ok := calls[p.Simplified()] + if !ok { + return "", false } - return []Escape{es} + // Join all calls together. + return strings.Join(s, " or "), true } - resolve := func(sub []Escape, inst ssa.Instruction, ec *EscapeCount) (es []Escape) { - for _, e := range sub { - if !ec.Record(e.Reason) { - continue // Skip. + state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + + // Build the exception list. + exemptions := make(map[LinePosition]string) + for _, f := range pass.Files { + for _, cg := range f.Comments { + for _, c := range cg.List { + p := pass.Fset.Position(c.Slash) + if strings.HasPrefix(strings.ToLower(c.Text), exempt) { + exemptions[LinePosition{ + Filename: p.Filename, + Line: p.Line, + }] = c.Text[len(exempt):] + } } - es = append(es, Escape{ - Reason: e.Reason, - Detail: e.Detail, - Chain: append([]CallSite{callSite(inst)}, e.Chain...), - }) } - return es } - state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) - var loadFunc func(*ssa.Function) []Escape // Used below. - - analyzeInstruction := func(inst ssa.Instruction, ec *EscapeCount) []Escape { + var loadFunc func(*ssa.Function) Escapes // Used below. + analyzeInstruction := func(inst ssa.Instruction) (es Escapes) { + cs := callSite(inst) + if _, ok := exemptions[cs.Resolved]; ok { + return // No escape. + } switch x := inst.(type) { case *ssa.Call: if x.Call.IsInvoke() { @@ -404,19 +680,15 @@ func run(pass *analysis.Pass) (interface{}, error) { // not, since we don't know the underlying // type. call, _ := hasCall(inst) - return escapes(interfaceInvoke, call, inst, ec) + es.Add(interfaceInvoke, call, cs) + return } switch x := x.Call.Value.(type) { case *ssa.Function: if x.Pkg == nil { // Can't resolve the package. - return escapes(unknownPackage, "no package", inst, ec) - } - - // Atomic functions are instrinics. We can - // assume that they don't escape. - if x.Pkg.Pkg.Name() == "atomic" { - return nil + es.Add(unknownPackage, "no package", cs) + return } // Is this a local function? If yes, call the @@ -424,7 +696,8 @@ func run(pass *analysis.Pass) (interface{}, error) { // local escapes are the escapes found in the // local function. if x.Pkg.Pkg == pass.Pkg { - return resolve(loadFunc(x), inst, ec) + es.MergeWithCall(loadFunc(x), cs) + return } // Recursively collect information from @@ -433,22 +706,26 @@ func run(pass *analysis.Pass) (interface{}, error) { if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) { // Unable to import the dependency; we must // declare these as escaping. - return escapes(unknownPackage, "no analysis", inst, ec) + es.Add(unknownPackage, "no analysis", cs) + return } // The escapes of this instruction are the // escapes of the called function directly. - return resolve(imp.Funcs[x.RelString(x.Pkg.Pkg)], inst, ec) + // Note that this may record many escapes. + es.MergeWithCall(imp.Funcs[x.RelString(x.Pkg.Pkg)], cs) + return case *ssa.Builtin: // Ignore elided escapes. if _, has := hasCall(inst); !has { - return nil + return } // Check if the builtin is escaping. for _, name := range escapingBuiltins { if x.Name() == name { - return escapes(builtin, name, inst, ec) + es.Add(builtin, name, cs) + return } } default: @@ -457,82 +734,87 @@ func run(pass *analysis.Pass) (interface{}, error) { // dispatches. We cannot actually look up what // this refers to using static analysis alone. call, _ := hasCall(inst) - return escapes(dynamicCall, call, inst, ec) + es.Add(dynamicCall, call, cs) } case *ssa.Alloc: // Ignore non-heap allocations. if !x.Heap { - return nil + return } // Ignore elided escapes. call, has := hasCall(inst) if !has { - return nil + return } // This is a real heap allocation. - return escapes(allocation, call, inst, ec) + es.Add(allocation, call, cs) case *ssa.MakeMap: - return escapes(builtin, "makemap", inst, ec) + es.Add(builtin, "makemap", cs) case *ssa.MakeSlice: - return escapes(builtin, "makeslice", inst, ec) + es.Add(builtin, "makeslice", cs) case *ssa.MakeClosure: - return escapes(builtin, "makeclosure", inst, ec) + es.Add(builtin, "makeclosure", cs) case *ssa.MakeChan: - return escapes(builtin, "makechan", inst, ec) + es.Add(builtin, "makechan", cs) } - return nil // No escapes. + return } - var analyzeBasicBlock func(*ssa.BasicBlock, *EscapeCount) []Escape // Recursive. - analyzeBasicBlock = func(block *ssa.BasicBlock, ec *EscapeCount) (rval []Escape) { + var analyzeBasicBlock func(*ssa.BasicBlock) []Escapes // Recursive. + analyzeBasicBlock = func(block *ssa.BasicBlock) (rval []Escapes) { for _, inst := range block.Instrs { - rval = append(rval, analyzeInstruction(inst, ec)...) + if es := analyzeInstruction(inst); !es.IsEmpty() { + rval = append(rval, es) + } } - return rval // N.B. may be empty. + return } - loadFunc = func(fn *ssa.Function) []Escape { + loadFunc = func(fn *ssa.Function) Escapes { // Is this already available? name := fn.RelString(pass.Pkg) - if es, ok := pef.Funcs[name]; ok { + if es, ok := mergedEscapes[name]; ok { return es } // In the case of a true cycle, we assume that the current - // function itself has no escapes until the rest of the - // analysis is complete. This will trip the above in the case - // of a cycle of any kind. - pef.Funcs[name] = nil + // function itself has no escapes. + // + // When evaluating the function again, the proper escapes will + // be filled in here. + allEscapes[name] = nil + mergedEscapes[name] = Escapes{} // Perform the basic analysis. - var ( - es []Escape - ec EscapeCount - ) + var es []Escapes if fn.Recover != nil { - es = append(es, analyzeBasicBlock(fn.Recover, &ec)...) + es = append(es, analyzeBasicBlock(fn.Recover)...) } for _, block := range fn.Blocks { - es = append(es, analyzeBasicBlock(block, &ec)...) + es = append(es, analyzeBasicBlock(block)...) } // Check for a stack split. if call, has := hasCall(fn); has { - es = append(es, Escape{ - Reason: stackSplit, - Detail: call, - Chain: []CallSite{CallSite{ - LocalPos: fn.Pos(), - Resolved: linePosition(fn, fn.Parent()), - }}, + var ss Escapes + ss.Add(stackSplit, call, CallSite{ + LocalPos: fn.Pos(), + Resolved: linePosition(fn, fn.Parent()), }) + es = append(es, ss) } // Save the result and return. - pef.Funcs[name] = es - return es + // + // Note that we merge the result when saving to the facts. It + // doesn't really matter the specific escapes, as long as we + // have recorded all the appropriate classes of escapes. + summary := MergeAll(es) + allEscapes[name] = es + mergedEscapes[name] = summary + return summary } // Complete all local functions. @@ -540,173 +822,76 @@ func run(pass *analysis.Pass) (interface{}, error) { loadFunc(fn) } - // Build the exception list. - exemptions := make(map[LinePosition]string) - for _, f := range pass.Files { - for _, cg := range f.Comments { - for _, c := range cg.List { - p := pass.Fset.Position(c.Slash) - if strings.HasPrefix(strings.ToLower(c.Text), exempt) { - exemptions[LinePosition{ - Filename: filepath.Base(p.Filename), - Line: p.Line, - }] = c.Text[len(exempt):] - } - } - } + if !localEscapes { + // Export all findings for future packages. We only do this in + // non-local escapes mode, and expect to run this analysis + // after the SelectAnalysis. + pass.ExportPackageFact(&packageEscapeFacts{ + Funcs: mergedEscapes, + }) } - // Delete everything matching the excemtions. - // - // This has the implication that exceptions are applied recursively, - // since this now modified set is what will be saved. - for name, escapes := range pef.Funcs { - var newEscapes []Escape - for _, escape := range escapes { - isExempt := false - for line, _ := range exemptions { - // Note that an exemption applies if it is - // marked as an exemption anywhere in the call - // chain. It need not be marked as escapes in - // the function itself, nor in the top-level - // caller. - for _, callSite := range escape.Chain { - if callSite.Resolved == line { - isExempt = true - break - } - } - if isExempt { - break - } - } - if !isExempt { - // Record this escape; not an exception. - newEscapes = append(newEscapes, escape) - } - } - pef.Funcs[name] = newEscapes // Update. - } - - // Export all findings for future packages. - pass.ExportPackageFact(&pef) - // Scan all functions for violations. for _, f := range pass.Files { // Scan all declarations. for _, decl := range f.Decls { - fdecl, ok := decl.(*ast.FuncDecl) // Function declaration? + fdecl, ok := decl.(*ast.FuncDecl) if !ok { continue } - // Is there a comment? - if fdecl.Doc == nil { - continue - } var ( reasons []EscapeReason - found bool local bool - testReasons = make(map[EscapeReason]bool) // reason -> local? + testReasons map[EscapeReason]bool ) - // Does the comment contain a +checkescape line? - for _, c := range fdecl.Doc.List { - if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) { - continue - } - if c.Text == magic { - // Default: hard reasons, local only. - reasons = hardReasons - local = true - } else if strings.HasPrefix(c.Text, magicParams) { - // Extract specific reasons. - types := strings.Split(c.Text[len(magicParams):], ",") - found = true // For below. - for i := 0; i < len(types); i++ { - if types[i] == "local" { - // Limit search to local escapes. - local = true - } else if types[i] == "all" { - // Append all reasons. - reasons = append(reasons, allReasons...) - } else if types[i] == "hard" { - // Append all hard reasons. - reasons = append(reasons, hardReasons...) - } else { - r, ok := escapeTypes[types[i]] - if !ok { - // This is not a valid escape reason. - pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) - continue - } - reasons = append(reasons, r) - } - } - } else if strings.HasPrefix(c.Text, testMagic) { - types := strings.Split(c.Text[len(testMagic):], ",") - local := false - for i := 0; i < len(types); i++ { - if types[i] == "local" { - local = true - } else { - r, ok := escapeTypes[types[i]] - if !ok { - // This is not a valid escape reason. - pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) - continue - } - if v, ok := testReasons[r]; ok && v { - // Already registered as local. - continue - } - testReasons[r] = local - } - } - } - } - if len(reasons) == 0 && found { - // A magic annotation was provided, but no reasons. - pass.Reportf(fdecl.Pos(), "no reasons provided") - continue + if localEscapes { + // Find all hard escapes. + reasons = hardReasons + } else { + // Find all declared reasons. + reasons, local, testReasons = findReasons(pass, fdecl) } // Scan for matches. fn := pass.TypesInfo.Defs[fdecl.Name].(*types.Func) - name := state.Pkg.Prog.FuncValue(fn).RelString(pass.Pkg) - es, ok := pef.Funcs[name] - if !ok { + fv := state.Pkg.Prog.FuncValue(fn) + if fv == nil { + continue + } + name := fv.RelString(pass.Pkg) + all, allOk := allEscapes[name] + merged, mergedOk := mergedEscapes[name] + if !allOk || !mergedOk { pass.Reportf(fdecl.Pos(), "internal error: function %s not found.", name) continue } - for _, e := range es { - for _, r := range reasons { - // Is does meet our local requirement? - if local && len(e.Chain) > 1 { - continue - } - // Does this match the reason? Emit - // with a full stack trace that - // explains why this violates our - // constraints. - if e.Reason == r { - pass.Reportf(e.Chain[0].LocalPos, "%s", e.String()) - } - } + + // Filter reasons and report. + // + // For the findings, we use all escapes. + for _, es := range all { + es.Filter(reasons, local) + es.Reportf(pass) } // Scan for test (required) matches. + // + // For tests we need only the merged escapes. testReasonsFound := make(map[EscapeReason]bool) - for _, e := range es { + for r := EscapeReason(0); r < reasonCount; r++ { + if merged.CallSites[r] == nil { + continue + } // Is this local? - local, ok := testReasons[e.Reason] - wantLocal := len(e.Chain) == 1 - testReasonsFound[e.Reason] = wantLocal + wantLocal, ok := testReasons[r] + isLocal := len(merged.CallSites[r]) == 1 + testReasonsFound[r] = isLocal if !ok { continue } - if local == wantLocal { - delete(testReasons, e.Reason) + if isLocal == wantLocal { + delete(testReasons, r) } } for reason, local := range testReasons { @@ -714,10 +899,8 @@ func run(pass *analysis.Pass) (interface{}, error) { pass.Reportf(fdecl.Pos(), fmt.Sprintf("testescapes not found: reason=%s, local=%t", reason, local)) } if len(testReasons) > 0 { - // Dump all reasons found to help in debugging. - for _, e := range es { - pass.Reportf(e.Chain[0].LocalPos, "escape found: %s", e.String()) - } + // Report for debugging. + merged.Reportf(pass) } } } diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go index 68d3f72cc..27991649f 100644 --- a/tools/checkescape/test1/test1.go +++ b/tools/checkescape/test1/test1.go @@ -17,7 +17,6 @@ package test1 import ( "fmt" - "reflect" ) // Interface is a generic interface. @@ -163,20 +162,6 @@ func dynamicRec(f func()) { Dynamic(f) } -// +mustescape:local,unknown -//go:noinline -//go:nosplit -func Unknown() { - _ = reflect.TypeOf((*Type)(nil)) // Does not actually escape. -} - -// +mustescape:unknown -//go:noinline -//go:nosplit -func unknownRec() { - Unknown() -} - //go:noinline //go:nosplit func internalFunc() { @@ -190,6 +175,7 @@ func Split() { // +mustescape:stack //go:noinline +//go:nosplit func splitRec() { Split() } diff --git a/tools/checkescape/test2/test2.go b/tools/checkescape/test2/test2.go index 7fce3e3be..067d5a1f4 100644 --- a/tools/checkescape/test2/test2.go +++ b/tools/checkescape/test2/test2.go @@ -81,14 +81,9 @@ func dynamicCrossPkg(f func()) { test1.Dynamic(f) } -// +mustescape:unknown -//go:noinline -func unknownCrossPkg() { - test1.Unknown() -} - // +mustescape:stack //go:noinline +//go:nosplit func splitCrosssPkt() { test1.Split() } diff --git a/tools/defs.bzl b/tools/defs.bzl index 40afcdb79..079ab806f 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,13 +7,14 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") load("//tools/nogo:defs.bzl", "nogo_test") # Delegate directly. build_test = _build_test +bzl_library = _bzl_library cc_binary = _cc_binary cc_flags_supplier = _cc_flags_supplier cc_grpc_library = _cc_grpc_library @@ -26,37 +27,53 @@ gbenchmark = _gbenchmark gazelle = _gazelle go_embed_data = _go_embed_data go_path = _go_path -go_test = _go_test gtest = _gtest grpcpp = _grpcpp loopback = _loopback pkg_deb = _pkg_deb pkg_tar = _pkg_tar py_binary = _py_binary -py_library = _py_library -py_requirement = _py_requirement -py_test = _py_test select_arch = _select_arch select_system = _select_system +short_path = _short_path rbe_platform = _rbe_platform rbe_toolchain = _rbe_toolchain vdso_linker_option = _vdso_linker_option +coreutil = _coreutil # Platform options. default_platform = _default_platform platforms = _platforms -def go_binary(name, **kwargs): +def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, **kwargs): """Wraps the standard go_binary. Args: name: the rule name. + nogo: enable nogo analysis. + pure: build a pure Go (no CGo) binary. + static: build a static binary. + x_defs: additional linker definitions. **kwargs: standard go_binary arguments. """ _go_binary( name = name, + pure = pure, + static = static, + x_defs = x_defs, **kwargs ) + if nogo: + # Note that the nogo rule applies only for go_library and go_test + # targets, therefore we construct a library from the binary sources. + _go_library( + name = name + "_nogo_library", + **kwargs + ) + nogo_test( + name = name + "_nogo", + deps = [":" + name + "_nogo_library"], + ) def calculate_sets(srcs): """Calculates special Go sets for templates. @@ -120,6 +137,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F stateify: whether statify is enabled (default: true). marshal: whether marshal is enabled (default: false). marshal_debug: whether the gomarshal tools emits debugging output (default: false). + nogo: enable nogo analysis. **kwargs: standard go_library arguments. """ all_srcs = srcs @@ -197,12 +215,33 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F for (suffix, _) in marshal_sets.items(): _go_test( name = name + suffix + "_abi_autogen_test", - srcs = [name + suffix + "_abi_autogen_test.go"], + srcs = [ + name + suffix + "_abi_autogen_test.go", + name + suffix + "_abi_autogen_unconditional_test.go", + ], library = ":" + name, deps = marshal_test_deps, **kwargs ) +def go_test(name, nogo = True, **kwargs): + """Wraps the standard go_test. + + Args: + name: the rule name. + nogo: enable nogo analysis. + **kwargs: standard go_test arguments. + """ + _go_test( + name = name, + **kwargs + ) + if nogo: + nogo_test( + name = name + "_nogo", + deps = [":" + name], + ) + def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): """Wraps the standard proto_library. diff --git a/tools/github/BUILD b/tools/github/BUILD new file mode 100644 index 000000000..aad088d13 --- /dev/null +++ b/tools/github/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "github", + srcs = ["main.go"], + nogo = False, + deps = [ + "//tools/github/nogo", + "//tools/github/reviver", + "@com_github_google_go_github_v28//github:go_default_library", + "@org_golang_x_oauth2//:go_default_library", + ], +) diff --git a/tools/github/main.go b/tools/github/main.go new file mode 100644 index 000000000..7a74dc033 --- /dev/null +++ b/tools/github/main.go @@ -0,0 +1,162 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary github is the entry point for GitHub utilities. +package main + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "os" + "os/exec" + "strings" + + "github.com/google/go-github/github" + "golang.org/x/oauth2" + "gvisor.dev/gvisor/tools/github/nogo" + "gvisor.dev/gvisor/tools/github/reviver" +) + +var ( + owner string + repo string + tokenFile string + path string + commit string + dryRun bool +) + +// Keep the options simple for now. Supports only a single path and repo. +func init() { + flag.StringVar(&owner, "owner", "", "GitHub project org/owner (required, except nogo dry-run)") + flag.StringVar(&repo, "repo", "", "GitHub repo (required, except nogo dry-run)") + flag.StringVar(&tokenFile, "oauth-token-file", "", "file containing the GitHub token (or GITHUB_TOKEN is set)") + flag.StringVar(&path, "path", ".", "path to scan (required for revive and nogo)") + flag.StringVar(&commit, "commit", "", "commit to associated (required for nogo, except dry-run)") + flag.BoolVar(&dryRun, "dry-run", false, "just print changes to be made") +} + +func main() { + // Set defaults from the environment. + repository := os.Getenv("GITHUB_REPOSITORY") + if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 { + owner = parts[0] + repo = parts[1] + } + + // Parse flags. + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "usage: %s [options] <command>\n", os.Args[0]) + fmt.Fprintf(flag.CommandLine.Output(), "commands: revive, nogo\n") + flag.PrintDefaults() + } + flag.Parse() + args := flag.Args() + if len(args) != 1 { + fmt.Fprintf(flag.CommandLine.Output(), "extra arguments: %s\n", strings.Join(args[1:], ", ")) + flag.Usage() + os.Exit(1) + } + + // Check for mandatory parameters. + command := args[0] + if len(owner) == 0 && (command != "nogo" || !dryRun) { + fmt.Fprintln(flag.CommandLine.Output(), "missing --owner option.") + flag.Usage() + os.Exit(1) + } + if len(repo) == 0 && (command != "nogo" || !dryRun) { + fmt.Fprintln(flag.CommandLine.Output(), "missing --repo option.") + flag.Usage() + os.Exit(1) + } + if len(path) == 0 { + fmt.Fprintln(flag.CommandLine.Output(), "missing --path option.") + flag.Usage() + os.Exit(1) + } + + // The access token may be passed as a file so it doesn't show up in + // command line arguments. It also may be provided through the + // environment to faciliate use through GitHub's CI system. + token := os.Getenv("GITHUB_TOKEN") + if len(tokenFile) != 0 { + bytes, err := ioutil.ReadFile(tokenFile) + if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + token = string(bytes) + } + var client *github.Client + if len(token) == 0 { + // Client is unauthenticated. + client = github.NewClient(nil) + } else { + // Using the above token. + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + ) + tc := oauth2.NewClient(context.Background(), ts) + client = github.NewClient(tc) + } + + switch command { + case "revive": + // Load existing GitHub bugs. + bugger, err := reviver.NewGitHubBugger(client, owner, repo, dryRun) + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting github issues: %v\n", err) + os.Exit(1) + } + // Scan the provided path. + rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) + if errs := rev.Run(); len(errs) > 0 { + fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) + for _, err := range errs { + fmt.Fprintf(os.Stderr, "\t%v\n", err) + } + os.Exit(1) + } + case "nogo": + // Did we get a commit? Try to extract one. + if len(commit) == 0 && !dryRun { + cmd := exec.Command("git", "rev-parse", "HEAD") + revBytes, err := cmd.Output() + if err != nil { + fmt.Fprintf(flag.CommandLine.Output(), "missing --commit option, unable to infer: %v\n", err) + flag.Usage() + os.Exit(1) + } + commit = strings.TrimSpace(string(revBytes)) + } + // Scan all findings. + poster := nogo.NewFindingsPoster(client, owner, repo, commit, dryRun) + if err := poster.Walk(path); err != nil { + fmt.Fprintln(os.Stderr, "Error finding nogo findings:", err) + os.Exit(1) + } + // Post to GitHub. + if err := poster.Post(); err != nil { + fmt.Fprintln(os.Stderr, "Error posting nogo findings:", err) + } + default: + // Not a known command. + fmt.Fprintf(flag.CommandLine.Output(), "unknown command: %s\n", command) + flag.Usage() + os.Exit(1) + } +} diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD new file mode 100644 index 000000000..0633eaf19 --- /dev/null +++ b/tools/github/nogo/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "nogo", + srcs = ["nogo.go"], + nogo = False, + visibility = [ + "//tools/github:__subpackages__", + ], + deps = [ + "//tools/nogo/util", + "@com_github_google_go_github_v28//github:go_default_library", + ], +) diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go new file mode 100644 index 000000000..b70dfe63b --- /dev/null +++ b/tools/github/nogo/nogo.go @@ -0,0 +1,126 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nogo provides nogo-related utilities. +package nogo + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/go-github/github" + "gvisor.dev/gvisor/tools/nogo/util" +) + +// FindingsPoster is a simple wrapper around the GitHub api. +type FindingsPoster struct { + owner string + repo string + commit string + dryRun bool + startTime time.Time + + findings map[util.Finding]struct{} + client *github.Client +} + +// NewFindingsPoster returns a object that can post findings. +func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun bool) *FindingsPoster { + return &FindingsPoster{ + owner: owner, + repo: repo, + commit: commit, + dryRun: dryRun, + startTime: time.Now(), + findings: make(map[util.Finding]struct{}), + client: client, + } +} + +// Walk walks the given path tree for findings files. +func (p *FindingsPoster) Walk(path string) error { + return filepath.Walk(path, func(filename string, info os.FileInfo, err error) error { + if err != nil { + return err + } + // Skip any directories or files not ending in .findings. + if !strings.HasSuffix(filename, ".findings") || info.IsDir() { + return nil + } + findings, err := util.ExtractFindingsFromFile(filename) + if err != nil { + return err + } + // Add all findings to the list. We use a map to ensure + // that each finding is unique. + for _, finding := range findings { + p.findings[finding] = struct{}{} + } + return nil + }) +} + +// Post posts all results to the GitHub API as a check run. +func (p *FindingsPoster) Post() error { + // Just show results? + if p.dryRun { + for finding, _ := range p.findings { + // Pretty print, so that this is useful for debugging. + fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message) + } + return nil + } + + // Construct the message. + title := "nogo" + count := len(p.findings) + status := "completed" + conclusion := "success" + if count > 0 { + conclusion = "failure" // Contains errors. + } + summary := fmt.Sprintf("%d findings.", count) + opts := github.CreateCheckRunOptions{ + Name: title, + HeadSHA: p.commit, + Status: &status, + Conclusion: &conclusion, + StartedAt: &github.Timestamp{p.startTime}, + CompletedAt: &github.Timestamp{time.Now()}, + Output: &github.CheckRunOutput{ + Title: &title, + Summary: &summary, + AnnotationsCount: &count, + }, + } + annotationLevel := "failure" // Always. + for finding, _ := range p.findings { + opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{ + Path: &finding.Path, + StartLine: &finding.Line, + EndLine: &finding.Line, + Message: &finding.Message, + Title: &finding.Category, + AnnotationLevel: &annotationLevel, + }) + } + + // Post to GitHub. + _, _, err := p.client.Checks.CreateCheckRun(context.Background(), p.owner, p.repo, opts) + return err +} diff --git a/tools/github/reviver/BUILD b/tools/github/reviver/BUILD new file mode 100644 index 000000000..7d78480a7 --- /dev/null +++ b/tools/github/reviver/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "reviver", + srcs = [ + "github.go", + "reviver.go", + ], + nogo = False, + visibility = [ + "//tools/github:__subpackages__", + ], + deps = ["@com_github_google_go_github_v28//github:go_default_library"], +) + +go_test( + name = "reviver_test", + size = "small", + srcs = [ + "github_test.go", + "reviver_test.go", + ], + library = ":reviver", + nogo = False, +) diff --git a/tools/issue_reviver/github/github.go b/tools/github/reviver/github.go index e07949c8f..a95df0fb6 100644 --- a/tools/issue_reviver/github/github.go +++ b/tools/github/reviver/github.go @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package github implements reviver.Bugger interface on top of Github issues. -package github +package reviver import ( "context" @@ -23,12 +22,10 @@ import ( "time" "github.com/google/go-github/github" - "golang.org/x/oauth2" - "gvisor.dev/gvisor/tools/issue_reviver/reviver" ) -// Bugger implements reviver.Bugger interface for github issues. -type Bugger struct { +// GitHubBugger implements Bugger interface for github issues. +type GitHubBugger struct { owner string repo string dryRun bool @@ -37,36 +34,25 @@ type Bugger struct { issues map[int]*github.Issue } -// NewBugger creates a new Bugger. -func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) { - b := &Bugger{ +// NewGitHubBugger creates a new GitHubBugger. +func NewGitHubBugger(client *github.Client, owner, repo string, dryRun bool) (*GitHubBugger, error) { + b := &GitHubBugger{ owner: owner, repo: repo, dryRun: dryRun, issues: map[int]*github.Issue{}, + client: client, } - if err := b.load(token); err != nil { + if err := b.load(); err != nil { return nil, err } return b, nil } -func (b *Bugger) load(token string) error { - ctx := context.Background() - if len(token) == 0 { - fmt.Print("No OAUTH token provided, using unauthenticated account.\n") - b.client = github.NewClient(nil) - } else { - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: token}, - ) - tc := oauth2.NewClient(ctx, ts) - b.client = github.NewClient(tc) - } - +func (b *GitHubBugger) load() error { err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) { opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts} - tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts) + tmps, resp, err := b.client.Issues.ListByRepo(context.Background(), b.owner, b.repo, opts) if err != nil { return resp, err } @@ -83,20 +69,15 @@ func (b *Bugger) load(token string) error { return nil } -// Activate implements reviver.Bugger. -func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) { - const prefix = "gvisor.dev/issue/" - - // First check if I can handle the TODO. - idStr := strings.TrimPrefix(todo.Issue, prefix) - if len(todo.Issue) == len(idStr) { - return false, nil - } - - id, err := strconv.Atoi(idStr) +// Activate implements Bugger.Activate. +func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { + id, err := parseIssueNo(todo.Issue) if err != nil { return true, err } + if id <= 0 { + return false, nil + } // Check against active issues cache. if _, ok := b.issues[id]; ok { @@ -115,7 +96,7 @@ func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) { l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment) } fmt.Fprintf(&comment, - "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%d%%22)", b.owner, b.repo, prefix, id) + "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%%22)", b.owner, b.repo, todo.Issue) if b.dryRun { fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String()) @@ -140,6 +121,23 @@ func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) { return true, nil } +// parseIssueNo parses the issue number out of the issue url. +func parseIssueNo(url string) (int, error) { + const prefix = "gvisor.dev/issue/" + + // First check if I can handle the TODO. + idStr := strings.TrimPrefix(url, prefix) + if len(url) == len(idStr) { + return 0, nil + } + + id, err := strconv.ParseInt(strings.TrimRight(idStr, "/"), 10, 64) + if err != nil { + return 0, err + } + return int(id), nil +} + func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error { opts := github.ListOptions{PerPage: 1000} for { diff --git a/tools/github/reviver/github_test.go b/tools/github/reviver/github_test.go new file mode 100644 index 000000000..5df7e3624 --- /dev/null +++ b/tools/github/reviver/github_test.go @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reviver + +import ( + "testing" +) + +func TestParseIssueNo(t *testing.T) { + testCases := []struct { + issue string + expectErr bool + expected int + }{ + { + issue: "gvisor.dev/issue/123", + expected: 123, + }, + { + issue: "gvisor.dev/issue/123/", + expected: 123, + }, + { + issue: "not a url", + expected: 0, + }, + { + issue: "gvisor.dev/issue//", + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.issue, func(t *testing.T) { + id, err := parseIssueNo(tc.issue) + if err != nil && !tc.expectErr { + t.Errorf("got error: %v", err) + } else if tc.expected != id { + t.Errorf("got: %v, want: %v", id, tc.expected) + } + }) + } +} diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/github/reviver/reviver.go index 682db0c01..2af7f0d59 100644 --- a/tools/issue_reviver/reviver/reviver.go +++ b/tools/github/reviver/reviver.go @@ -26,7 +26,7 @@ import ( "sync" ) -// This is how a TODO looks like. +// regexTodo matches a TODO or FIXME comment. var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`) // Bugger interface is called for every TODO found in the code. If it can handle diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go index a9fb1f9f1..a9fb1f9f1 100644 --- a/tools/issue_reviver/reviver/reviver_test.go +++ b/tools/github/reviver/reviver_test.go diff --git a/tools/go_branch.sh b/tools/go_branch.sh index 093de89b4..e5c060024 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -40,10 +40,15 @@ trap finish EXIT # Record the current working commit. declare -r head=$(git describe --always) -# We expect to have an existing go branch that we will use as the basis for -# this commit. That branch may be empty, but it must exist. +# We expect to have an existing go branch that we will use as the basis for this +# commit. That branch may be empty, but it must exist. We search for this branch +# using the local branch, the "origin" branch, and other remotes, in order. git fetch --all -declare -r go_branch=$(git show-ref --hash go) +declare -r go_branch=$( \ + git show-ref --hash refs/heads/go || \ + git show-ref --hash refs/remotes/origin/go || \ + git show-ref --hash go | head -n 1 \ +) # Clone the current repository to the temporary directory, and check out the # current go_branch directory. We move to the new repository for convenience. @@ -66,6 +71,11 @@ git checkout -b go "${go_branch}" git merge --no-commit --strategy ours ${head} || \ git merge --allow-unrelated-histories --no-commit --strategy ours ${head} +# Normalize the permissions on the old branch. Note that they should be +# normalized if constructed by this tool, but we do so before the rsync. +find . -type f -exec chmod 0644 {} \; +find . -type d -exec chmod 0755 {} \; + # Sync the entire gopath_dir. rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" . @@ -86,7 +96,11 @@ EOF # There are a few solitary files that can get left behind due to the way bazel # constructs the gopath target. Note that we don't find all Go files here # because they may correspond to unused templates, etc. -cp "${repo_orig}"/runsc/*.go runsc/ +declare -ar binaries=( "runsc" "shim/v1" "shim/v2" ) +for target in "${binaries[@]}"; do + mkdir -p "${target}" + cp "${repo_orig}/${target}"/*.go "${target}/" +done # Normalize all permissions. The way bazel constructs the :gopath tree may leave # some strange permissions on files. We don't have anything in this tree that diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 32a949c93..807c08ead 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary") +load("//tools:defs.bzl", "bzl_library", "go_binary") package(licenses = ["notice"]) @@ -13,26 +13,8 @@ go_binary( deps = ["//tools/go_generics/globals"], ) -genrule( - name = "go_generics_tests", - srcs = glob(["generics_tests/**"]) + [":go_generics"], - outs = ["go_generics_tests.tgz"], - cmd = "tar -czvhf $@ $(SRCS)", -) - -genrule( - name = "go_generics_test_bundle", - srcs = [ - ":go_generics_tests.tgz", - ":go_generics_unittest.sh", - ], - outs = ["go_generics_test.sh"], - cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@", - executable = True, -) - -sh_test( - name = "go_generics_test", - size = "small", - srcs = ["go_generics_test.sh"], +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], ) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index 8c9995fd4..33329cf28 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -1,11 +1,24 @@ +"""Generics support via go_generics.""" + +TemplateInfo = provider( + fields = { + "types": "required types", + "opt_types": "optional types", + "consts": "required consts", + "opt_consts": "optional consts", + "deps": "package dependencies", + "file": "merged template", + }, +) + def _go_template_impl(ctx): - input = ctx.files.srcs + srcs = ctx.files.srcs output = ctx.outputs.out - args = ["-o=%s" % output.path] + [f.path for f in input] + args = ["-o=%s" % output.path] + [f.path for f in srcs] ctx.actions.run( - inputs = input, + inputs = srcs, outputs = [output], mnemonic = "GoGenericsTemplate", progress_message = "Building Go template %s" % ctx.label, @@ -13,14 +26,14 @@ def _go_template_impl(ctx): executable = ctx.executable._tool, ) - return struct( + return [TemplateInfo( types = ctx.attr.types, opt_types = ctx.attr.opt_types, consts = ctx.attr.consts, opt_consts = ctx.attr.opt_consts, deps = ctx.attr.deps, file = output, - ) + )] """ Generates a Go template from a set of Go files. @@ -43,7 +56,7 @@ go_template = rule( implementation = _go_template_impl, attrs = { "srcs": attr.label_list(mandatory = True, allow_files = True), - "deps": attr.label_list(allow_files = True), + "deps": attr.label_list(allow_files = True, cfg = "target"), "types": attr.string_list(), "opt_types": attr.string_list(), "consts": attr.string_list(), @@ -55,8 +68,14 @@ go_template = rule( }, ) +TemplateInstanceInfo = provider( + fields = { + "srcs": "source files", + }, +) + def _go_template_instance_impl(ctx): - template = ctx.attr.template + template = ctx.attr.template[TemplateInfo] output = ctx.outputs.out # Check that all required types are defined. @@ -81,20 +100,21 @@ def _go_template_instance_impl(ctx): # Build the argument list. args = ["-i=%s" % template.file.path, "-o=%s" % output.path] - args += ["-p=%s" % ctx.attr.package] + if ctx.attr.package: + args.append("-p=%s" % ctx.attr.package) if len(ctx.attr.prefix) > 0: - args += ["-prefix=%s" % ctx.attr.prefix] + args.append("-prefix=%s" % ctx.attr.prefix) if len(ctx.attr.suffix) > 0: - args += ["-suffix=%s" % ctx.attr.suffix] + args.append("-suffix=%s" % ctx.attr.suffix) args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()] args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()] args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()] if ctx.attr.anon: - args += ["-anon"] + args.append("-anon") ctx.actions.run( inputs = [template.file], @@ -105,9 +125,9 @@ def _go_template_instance_impl(ctx): executable = ctx.executable._tool, ) - return struct( - files = depset([output]), - ) + return [TemplateInstanceInfo( + srcs = [output], + )] """ Instantiates a Go template by replacing all generic types with concrete ones. @@ -125,14 +145,14 @@ Args: go_template_instance = rule( implementation = _go_template_instance_impl, attrs = { - "template": attr.label(mandatory = True, providers = ["types"]), + "template": attr.label(mandatory = True), "prefix": attr.string(), "suffix": attr.string(), "types": attr.string_dict(), "consts": attr.string_dict(), "imports": attr.string_dict(), "anon": attr.bool(mandatory = False, default = False), - "package": attr.string(mandatory = True), + "package": attr.string(mandatory = False), "out": attr.output(mandatory = True), "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")), }, diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_stmts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_types/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt deleted file mode 100644 index a5e9d26de..000000000 --- a/tools/go_generics/generics_tests/anon/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New -anon diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt deleted file mode 100644 index 4fb59dce8..000000000 --- a/tools/go_generics/generics_tests/consts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC" diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt deleted file mode 100644 index 87324be79..000000000 --- a/tools/go_generics/generics_tests/imports/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt deleted file mode 100644 index 9c8ecaada..000000000 --- a/tools/go_generics/generics_tests/remove_typedef/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=U diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt deleted file mode 100644 index 7832ef66f..000000000 --- a/tools/go_generics/generics_tests/simple/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh deleted file mode 100755 index 44b22db91..000000000 --- a/tools/go_generics/go_generics_unittest.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Bash "safe-mode": Treat command failures as fatal (even those that occur in -# pipes), and treat unset variables as errors. -set -eu -o pipefail - -# This file will be generated as a self-extracting shell script in order to -# eliminate the need for any runtime dependencies. The tarball at the end will -# include the go_generics binary, as well as a subdirectory named -# generics_tests. See the BUILD file for more information. -declare -r temp=$(mktemp -d) -function cleanup() { - rm -rf "${temp}" -} -# trap cleanup EXIT - -# Print message in "$1" then exit with status 1. -function die () { - echo "$1" 1>&2 - exit 1 -} - -# This prints the line number of __BUNDLE__ below, that should be the last line -# of this script. After that point, the concatenated archive will be the -# contents. -declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0` -tail -n+"${tgz}" $0 | tar -xzv -C "${temp}" - -# The target for the test. -declare -r binary="$(find ${temp} -type f -a -name go_generics)" -declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*" - -# Go through all test cases. -for f in ${input_dirs}; do - base=$(basename "${f}") - - # Run go_generics on the input file. - opts=$(head -n 1 ${f}/opts.txt) - out="${f}/output/generated.go" - expected="${f}/output/output.go" - ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\"" - - # Compare the outputs. - diff ${expected} ${out} - if [ $? -ne 0 ]; then - echo "Expected:" - cat ${expected} - echo "Actual:" - cat ${out} - die "Actual output is different from expected for test \"${base}\"" - fi -done - -echo "PASS" -exit 0 -__BUNDLE__ diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go index f6a331123..e0345500f 100644 --- a/tools/go_generics/go_merge/main.go +++ b/tools/go_generics/go_merge/main.go @@ -77,6 +77,7 @@ func main() { // Create a new declaration slice with all imports at the top, merging any // redundant imports. imports := make(map[string]*ast.ImportSpec) + var importNames []string // Keep imports in the original order to get deterministic output. var anonImports []*ast.ImportSpec for _, d := range f.Decls { if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT { @@ -98,6 +99,7 @@ func main() { } } else { imports[n] = i + importNames = append(importNames, n) } } } @@ -112,8 +114,8 @@ func main() { Lparen: token.NoPos + 1, Specs: make([]ast.Spec, 0, l), } - for _, i := range imports { - d.Specs = append(d.Specs, i) + for _, i := range importNames { + d.Specs = append(d.Specs, imports[i]) } for _, i := range anonImports { d.Specs = append(d.Specs, i) diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go index 148dc7216..90d3aa1e0 100644 --- a/tools/go_generics/imports.go +++ b/tools/go_generics/imports.go @@ -21,6 +21,7 @@ import ( "go/format" "go/parser" "go/token" + "sort" "strconv" "gvisor.dev/gvisor/tools/go_generics/globals" @@ -132,10 +133,17 @@ func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) { if len(importsUsed) == 0 { return nil, nil } + var names []string + for n := range importsUsed { + names = append(names, n) + } + // Sort the new imports for deterministic build outputs. + sort.Strings(names) // Create spec array for each new import. specs := make([]ast.Spec, 0, len(importsUsed)) - for _, i := range importsUsed { + for _, n := range names { + i := importsUsed[n] specs = append(specs, &ast.ImportSpec{ Name: &ast.Ident{Name: i.newName}, Path: &ast.BasicLit{Value: i.path}, diff --git a/tools/go_generics/tests/BUILD b/tools/go_generics/tests/BUILD new file mode 100644 index 000000000..7547a6b53 --- /dev/null +++ b/tools/go_generics/tests/BUILD @@ -0,0 +1,7 @@ +load("//tools:defs.bzl", "bzl_library") + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/go_generics/tests/all_stmts/BUILD b/tools/go_generics/tests/all_stmts/BUILD new file mode 100644 index 000000000..a4a7c775a --- /dev/null +++ b/tools/go_generics/tests/all_stmts/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "all_stmts", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/tests/all_stmts/input.go index 4791d1ff1..4791d1ff1 100644 --- a/tools/go_generics/generics_tests/all_stmts/input.go +++ b/tools/go_generics/tests/all_stmts/input.go diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/tests/all_stmts/output.go index a53d84535..a53d84535 100644 --- a/tools/go_generics/generics_tests/all_stmts/output/output.go +++ b/tools/go_generics/tests/all_stmts/output.go diff --git a/tools/go_generics/tests/all_types/BUILD b/tools/go_generics/tests/all_types/BUILD new file mode 100644 index 000000000..60b1fd314 --- /dev/null +++ b/tools/go_generics/tests/all_types/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "all_types", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/tests/all_types/input.go index 3575d02ec..6f85bbb69 100644 --- a/tools/go_generics/generics_tests/all_types/input.go +++ b/tools/go_generics/tests/all_types/input.go @@ -14,7 +14,9 @@ package tests -import "./lib" +import ( + "./lib" +) type T int diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/tests/all_types/lib/lib.go index 988786496..988786496 100644 --- a/tools/go_generics/generics_tests/all_types/lib/lib.go +++ b/tools/go_generics/tests/all_types/lib/lib.go diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/tests/all_types/output.go index 41fd147a1..c0bbebfe7 100644 --- a/tools/go_generics/generics_tests/all_types/output/output.go +++ b/tools/go_generics/tests/all_types/output.go @@ -14,7 +14,9 @@ package main -import "./lib" +import ( + "./lib" +) type newType struct { a Q diff --git a/tools/go_generics/tests/anon/BUILD b/tools/go_generics/tests/anon/BUILD new file mode 100644 index 000000000..ef24f4b25 --- /dev/null +++ b/tools/go_generics/tests/anon/BUILD @@ -0,0 +1,18 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "anon", + anon = True, + inputs = ["input.go"], + output = "output.go", + suffix = "New", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/tests/anon/input.go index 44086d522..44086d522 100644 --- a/tools/go_generics/generics_tests/anon/input.go +++ b/tools/go_generics/tests/anon/input.go diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/tests/anon/output.go index 160cddf79..7fa791853 100644 --- a/tools/go_generics/generics_tests/anon/output/output.go +++ b/tools/go_generics/tests/anon/output.go @@ -35,8 +35,8 @@ func (f FooNew) GetBar(name string) Q { func foobarNew() { a := BazNew{} - a.Q = 0 // should not be renamed, this is a limitation + a.Q = 0 b := otherpkg.UnrelatedType{} - b.Q = 0 // should not be renamed, this is a limitation + b.Q = 0 } diff --git a/tools/go_generics/tests/consts/BUILD b/tools/go_generics/tests/consts/BUILD new file mode 100644 index 000000000..fd7caccad --- /dev/null +++ b/tools/go_generics/tests/consts/BUILD @@ -0,0 +1,23 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "consts", + consts = { + "c1": "20", + "z": "600", + "v": "3.3", + "s": "\"def\"", + "A": "20", + "C": "100", + "S": "\"def\"", + "T": "\"ABC\"", + }, + inputs = ["input.go"], + output = "output.go", +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/tests/consts/input.go index 04b95fcc6..04b95fcc6 100644 --- a/tools/go_generics/generics_tests/consts/input.go +++ b/tools/go_generics/tests/consts/input.go diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/tests/consts/output.go index 18d316cc9..18d316cc9 100644 --- a/tools/go_generics/generics_tests/consts/output/output.go +++ b/tools/go_generics/tests/consts/output.go diff --git a/tools/go_generics/tests/defs.bzl b/tools/go_generics/tests/defs.bzl new file mode 100644 index 000000000..6277c3947 --- /dev/null +++ b/tools/go_generics/tests/defs.bzl @@ -0,0 +1,67 @@ +"""Generics tests.""" + +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") + +def _go_generics_test_impl(ctx): + runner = ctx.actions.declare_file(ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "exec diff --ignore-blank-lines --ignore-matching-lines=^[[:space:]]*// %s %s" % ( + ctx.files.template_output[0].short_path, + ctx.files.expected_output[0].short_path, + ), + "", + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + return [DefaultInfo( + executable = runner, + runfiles = ctx.runfiles( + files = ctx.files.template_output + ctx.files.expected_output, + collect_default = True, + collect_data = True, + ), + )] + +_go_generics_test = rule( + implementation = _go_generics_test_impl, + attrs = { + "template_output": attr.label(mandatory = True, allow_single_file = True), + "expected_output": attr.label(mandatory = True, allow_single_file = True), + }, + test = True, +) + +def go_generics_test(name, inputs, output, types = None, consts = None, **kwargs): + """Instantiates a generics test. + + Args: + name: the name of the test. + inputs: all the input files. + output: the output files. + types: the template types (dictionary). + consts: the template consts (dictionary). + **kwargs: additional arguments for the template_instance. + """ + if types == None: + types = dict() + if consts == None: + consts = dict() + go_template( + name = name + "_template", + srcs = inputs, + types = types.keys(), + consts = consts.keys(), + ) + go_template_instance( + name = name + "_output", + template = ":" + name + "_template", + out = name + "_output.go", + types = types, + consts = consts, + **kwargs + ) + _go_generics_test( + name = name + "_test", + template_output = name + "_output.go", + expected_output = output, + ) diff --git a/tools/go_generics/tests/imports/BUILD b/tools/go_generics/tests/imports/BUILD new file mode 100644 index 000000000..a86223d41 --- /dev/null +++ b/tools/go_generics/tests/imports/BUILD @@ -0,0 +1,24 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "imports", + consts = { + "n": "math.Uint32", + "m": "math.Uint64", + }, + imports = { + "sync": "sync", + "math": "mymathpath", + }, + inputs = ["input.go"], + output = "output.go", + types = { + "T": "sync.Mutex", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/tests/imports/input.go index 0f032c2a1..0f032c2a1 100644 --- a/tools/go_generics/generics_tests/imports/input.go +++ b/tools/go_generics/tests/imports/input.go diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/tests/imports/output.go index 2488ca58c..2488ca58c 100644 --- a/tools/go_generics/generics_tests/imports/output/output.go +++ b/tools/go_generics/tests/imports/output.go diff --git a/tools/go_generics/tests/remove_typedef/BUILD b/tools/go_generics/tests/remove_typedef/BUILD new file mode 100644 index 000000000..46457cec6 --- /dev/null +++ b/tools/go_generics/tests/remove_typedef/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "remove_typedef", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "U", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/tests/remove_typedef/input.go index cf632bae7..cf632bae7 100644 --- a/tools/go_generics/generics_tests/remove_typedef/input.go +++ b/tools/go_generics/tests/remove_typedef/input.go diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/tests/remove_typedef/output.go index d44fd8e1c..d44fd8e1c 100644 --- a/tools/go_generics/generics_tests/remove_typedef/output/output.go +++ b/tools/go_generics/tests/remove_typedef/output.go diff --git a/tools/go_generics/tests/simple/BUILD b/tools/go_generics/tests/simple/BUILD new file mode 100644 index 000000000..4b9265ea4 --- /dev/null +++ b/tools/go_generics/tests/simple/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "simple", + inputs = ["input.go"], + output = "output.go", + suffix = "New", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/tests/simple/input.go index 2a917f16c..2a917f16c 100644 --- a/tools/go_generics/generics_tests/simple/input.go +++ b/tools/go_generics/tests/simple/input.go diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/tests/simple/output.go index 6bfa0b25b..6bfa0b25b 100644 --- a/tools/go_generics/generics_tests/simple/output/output.go +++ b/tools/go_generics/tests/simple/output.go diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index be49cf9c8..f79defea7 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary") +load("//tools:defs.bzl", "bzl_library", "go_binary") licenses(["notice"]) @@ -17,3 +17,9 @@ config_setting( name = "marshal_config_verbose", values = {"define": "gomarshal=verbose"}, ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 4886efddf..d8045c295 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -3,20 +3,19 @@ This package implements the go_marshal utility. # Overview `go_marshal` is a code generation utility similar to `go_stateify` for -automatically generating code to marshal go data structures to memory. +marshalling go data structures to and from memory. `go_marshal` attempts to improve on `binary.Write` and the sentry's -`binary.Marshal` by moving the go runtime reflection necessary to marshal a -struct to compile-time. +`binary.Marshal` by moving the expensive use of reflection from runtime to +compile-time. -`go_marshal` automatically generates implementations for `abi.Marshallable` and -`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall -implementations) can directly invoke `safemem.Reader.ReadToBlocks` and -`safemem.Writer.WriteFromBlocks`. Data structures that require custom -serialization will have manual implementations for these interfaces. +`go_marshal` automatically generates implementations for `marshal.Marshallable` +interface. Data structures that require custom serialization can be accomodated +through a manual implementation this interface. Data structures can be flagged for code generation by adding a struct-level -comment `// +marshal`. +comment `// +marshal`. For additional details and options, see the documentation +for the `marshal.Marshallable` interface. # Usage @@ -76,7 +75,7 @@ intended for ABI structs, which have these additional restrictions: dependent native pointer size. - Fields must either be a primitive integer type (`byte`, - `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable. + `[u]int{8,16,32,64}`), or of a type that implements `marshal.Marshallable`. - `int` and `uint` fields are not allowed. Use an explicitly-sized numeric type. @@ -114,3 +113,18 @@ The following are some guidelines for modifying the `go_marshal` tool: - No runtime reflection in the code generated for the marshallable interface. The entire point of the tool is to avoid runtime reflection. The generated tests may use reflection. + +## Debugging + +To enable debugging output from the go-marshal tool, use one of the following +options, depending on how go-marshal is being invoked: + +- Pass `--define gomarshal=verbose` to the bazel command. Note that this can + generate a lot of output depending on what's being compiled, as this will + enable debugging for all packages built by the command. + +- Set `marshal_debug = True` on the top-level `go_library` BUILD rule. + +- Set `debug = True` on the `go_marshal` BUILD rule. + +- Pass `-debug` to the go-marshal tool invocation. diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index 323e33882..f44f83eab 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -4,11 +4,13 @@ def _go_marshal_impl(ctx): """Execute the go_marshal tool.""" output = ctx.outputs.lib output_test = ctx.outputs.test + output_test_unconditional = ctx.outputs.test_unconditional # Run the marshal command. args = ["-output=%s" % output.path] - args += ["-pkg=%s" % ctx.attr.package] - args += ["-output_test=%s" % output_test.path] + args.append("-pkg=%s" % ctx.attr.package) + args.append("-output_test=%s" % output_test.path) + args.append("-output_test_unconditional=%s" % output_test_unconditional.path) if ctx.attr.debug: args += ["-debug"] @@ -18,7 +20,7 @@ def _go_marshal_impl(ctx): args += [f.path for f in src.files.to_list()] ctx.actions.run( inputs = ctx.files.srcs, - outputs = [output, output_test], + outputs = [output, output_test, output_test_unconditional], mnemonic = "GoMarshal", progress_message = "go_marshal: %s" % ctx.label, arguments = args, @@ -48,6 +50,7 @@ go_marshal = rule( outputs = { "lib": "%{name}_unsafe.go", "test": "%{name}_test.go", + "test_unconditional": "%{name}_unconditional_test.go", }, ) @@ -56,7 +59,7 @@ marshal_deps = [ "//pkg/gohacks", "//pkg/safecopy", "//pkg/usermem", - "//tools/go_marshal/marshal", + "//pkg/marshal", ] # marshal_test_deps are required by test targets. diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 177013dbb..56fbcb5d2 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -38,8 +38,8 @@ import ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner", - "length", "limit", "ptr", "size", "src", "srcs", "task", "val", + "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx", + "inner", "length", "limit", "ptr", "size", "src", "srcs", "val", // All single-letter identifiers. } @@ -68,6 +68,8 @@ type Generator struct { output *os.File // Output file to write generated tests. outputTest *os.File + // Output file to write unconditionally generated tests. + outputTestUC *os.File // Package name for the generated file. pkg string // Set of extra packages to import in the generated file. @@ -75,7 +77,7 @@ type Generator struct { } // NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { +func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) @@ -84,12 +86,17 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G if err != nil { return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) } + fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open unconditional test output file %q: %v", out, err) + } g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - pkg: pkg, - imports: newImportTable(), + inputs: srcs, + output: f, + outputTest: fTest, + outputTestUC: fTestUC, + pkg: pkg, + imports: newImportTable(), } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as @@ -107,7 +114,7 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G g.imports.add("gvisor.dev/gvisor/pkg/gohacks") g.imports.add("gvisor.dev/gvisor/pkg/safecopy") g.imports.add("gvisor.dev/gvisor/pkg/usermem") - g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal") + g.imports.add("gvisor.dev/gvisor/pkg/marshal") return &g, nil } @@ -413,13 +420,13 @@ func (g *Generator) Run() error { for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. - for ref, _ := range impl.ms { + for ref := range impl.ms { ms[ref] = struct{}{} } impls = append(impls, impl) // Collect imports referenced by the generated code and add them to // the list of imports we need to copy to the generated code. - for name, _ := range impl.is { + for name := range impl.is { if !g.imports.markUsed(name) { panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) } @@ -454,6 +461,46 @@ func (g *Generator) Run() error { // source file. func (g *Generator) writeTests(ts []*testGenerator) error { var b sourceBuffer + + // Write the unconditional test file. This file is always compiled, + // regardless of what build tags were specified on the original input + // files. We use this file to guarantee we never end up with an empty test + // file, as that causes the build to fail with "no tests/benchmarks/examples + // found". + // + // There's no easy way to determine ahead of time if we'll end up with an + // empty build file since build constraints can arbitrarily cause some of + // the original types to be not defined. We also have no way to tell bazel + // to omit the entire test suite since the output files are already defined + // before go-marshal is called. + b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") + b.emit("package %s\n\n", g.pkg) + b.emit("func Example() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty, and ensures this package contains at\n") + b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n") + b.emit("// input package is marked marshallable, but emitting no testable entities \n") + b.emit("// results in a build failure.\n") + }) + b.emit("}\n") + if err := b.write(g.outputTestUC); err != nil { + return err + } + + // Now generate the real test file that contains the real types we + // processed. These need to be conditionally compiled according to the build + // tags, as the original types may not be defined under all build + // configurations. + + b.reset() + b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") + + // Emit build tags. + if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n\n") + } + b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { return err @@ -470,26 +517,6 @@ func (g *Generator) writeTests(ts []*testGenerator) error { } // Write test functions. - - // If we didn't generate any Marshallable implementations, we can't just - // emit an empty test file, since that causes the build to fail with "no - // tests/benchmarks/examples found". Unfortunately we can't signal bazel to - // omit the entire package since the outputs are already defined before - // go-marshal is called. If we'd otherwise emit an empty test suite, emit an - // empty example instead. - if len(ts) == 0 { - b.reset() - b.emit("func Example() {\n") - b.inIndent(func() { - b.emit("// This example is intentionally empty to ensure this file contains at least\n") - b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") - b.emit("// is marked marshallable, but emitting a test file with no entities results\n") - b.emit("// in a build failure.\n") - }) - b.emit("}\n") - return b.write(g.outputTest) - } - for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index e3c3dac63..36447b86b 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -43,8 +43,8 @@ type interfaceGenerator struct { // of t's interfaces. ms map[string]struct{} - // as records embedded fields in t that are potentially not packed. The key - // is the accessor for the field. + // as records fields in t that are potentially not packed. The key is the + // accessor for the field. as map[string]struct{} } @@ -224,7 +224,7 @@ func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) g.emit("// must live until the use above.\n") - g.emit("runtime.KeepAlive(%s)\n", ptrVar) + g.emit("runtime.KeepAlive(%s) // escapes: replaced by intrinsic.\n", ptrVar) } func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index 72ef03a22..7525b52da 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -102,11 +102,11 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) @@ -114,19 +114,19 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index 39f654ea8..7edaf666c 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -154,11 +154,11 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) @@ -166,19 +166,19 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) @@ -211,7 +211,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType) g.emit("//go:nosplit\n") - g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) + g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) g.inIndent(func() { g.emit("count := len(dst)\n") g.emit("if count == 0 {\n") @@ -223,7 +223,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emitCastSliceToByteSlice("&dst", "buf", "size * count") - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive("dst") g.emit("return length, err\n") }) @@ -231,7 +231,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) g.emit("//go:nosplit\n") - g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) + g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) g.inIndent(func() { g.emit("count := len(src)\n") g.emit("if count == 0 {\n") @@ -243,7 +243,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emitCastSliceToByteSlice("&src", "buf", "size * count") - g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive("src") g.emit("return length, err\n") }) diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 9cd3c9579..fe76d3785 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -20,6 +20,7 @@ package gomarshal import ( "fmt" "go/ast" + "sort" "strings" ) @@ -40,6 +41,8 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { for accessor, _ := range g.as { cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) } + // Sort expressions for determinstic build outputs. + sort.Strings(cs) return strings.Join(cs, " && "), true } @@ -48,12 +51,6 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { // later. func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { forEachStructField(st, func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - forEachStructField(st, func(f *ast.Field) { fieldDispatcher{ primitive: func(_, t *ast.Ident) { g.validatePrimitiveNewtype(t) @@ -98,7 +95,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { var dynamicSizeTerms []string forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { + primitive: func(_, t *ast.Ident) { if size, dynamic := g.scalarSize(t); !dynamic { primitiveSize += size } else { @@ -106,13 +103,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) } }, - selector: func(n, tX, tSel *ast.Ident) { + selector: func(_, tX, tSel *ast.Ident) { tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) g.recordUsedImport(tX.Name) g.recordUsedMarshallable(tName) dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { lenExpr := g.arrayLenExpr(a) if size, dynamic := g.scalarSize(t); !dynamic { dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) @@ -268,6 +265,10 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } if thisPacked { g.recordUsedImport("safecopy") g.recordUsedImport("unsafe") @@ -277,16 +278,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) }) g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) + g.inIndent(fallback) g.emit("}\n") } else { g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) } } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) + fallback() } }) g.emit("}\n\n") @@ -294,6 +292,10 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } if thisPacked { g.recordUsedImport("safecopy") g.recordUsedImport("unsafe") @@ -303,16 +305,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) }) g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) + g.inIndent(fallback) g.emit("}\n") } else { g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) } } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) + fallback() } }) g.emit("}\n\n") @@ -321,13 +320,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) - g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") } if thisPacked { g.recordUsedImport("reflect") @@ -341,7 +340,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { // Fast serialization. g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") } else { @@ -354,9 +353,9 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") @@ -364,12 +363,12 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") g.emit("// partially unmarshalled struct.\n") g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) @@ -387,7 +386,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { // Fast deserialization. g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") } else { @@ -398,13 +397,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("length, err := w.Write(buf)\n") + g.emit("length, err := writer.Write(buf)\n") g.emit("return int64(length), err\n") } if thisPacked { @@ -419,7 +418,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { // Fast serialization. g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := w.Write(buf)\n") + g.emit("length, err := writer.Write(buf)\n") g.emitKeepAlive(g.r) g.emit("return int64(length), err\n") } else { @@ -440,7 +439,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.recordUsedImport("usermem") g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) - g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) + g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) g.inIndent(func() { g.emit("count := len(dst)\n") g.emit("if count == 0 {\n") @@ -452,8 +451,8 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(size * count)\n") - g.emit("length, err := task.CopyInBytes(addr, buf)\n\n") + g.emit("buf := cc.CopyScratchBuffer(size * count)\n") + g.emit("length, err := cc.CopyInBytes(addr, buf)\n\n") g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") g.emit("limit := length/size\n") @@ -463,8 +462,10 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, }) g.emit("}\n\n") - g.emit("// Handle any final partial object.\n") - g.emit("if length < size*count && length%size != 0 {\n") + g.emit("// Handle any final partial object. buf is guaranteed to be long enough for the\n") + g.emit("// final element, but may not contain valid data for the entire range. This may\n") + g.emit("// result in unmarshalling zero values for some parts of the object.\n") + g.emit("if length%size != 0 {\n") g.inIndent(func() { g.emit("idx := limit\n") g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") @@ -485,7 +486,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, // Fast deserialization. g.emitCastSliceToByteSlice("&dst", "buf", "size * count") - g.emit("length, err := task.CopyInBytes(addr, buf)\n") + g.emit("length, err := cc.CopyInBytes(addr, buf)\n") g.emitKeepAlive("dst") g.emit("return length, err\n") } else { @@ -495,7 +496,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("}\n\n") g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) - g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) + g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) g.inIndent(func() { g.emit("count := len(src)\n") g.emit("if count == 0 {\n") @@ -507,13 +508,13 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("buf := cc.CopyScratchBuffer(size * count)\n") g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") }) g.emit("}\n") - g.emit("return task.CopyOutBytes(addr, buf)\n") + g.emit("return cc.CopyOutBytes(addr, buf)\n") } if thisPacked { g.recordUsedImport("reflect") @@ -527,7 +528,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, // Fast serialization. g.emitCastSliceToByteSlice("&src", "buf", "size * count") - g.emit("length, err := task.CopyOutBytes(addr, buf)\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf)\n") g.emitKeepAlive("src") g.emit("return length, err\n") } else { diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index d94314302..6a42691cd 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -79,7 +79,7 @@ type fieldDispatcher struct { } // Precondition: All dispatch callbacks that will be invoked must be -// provided. Embedded fields are not allowed, len(f.Names) >= 1. +// provided. func (fd fieldDispatcher) dispatch(f *ast.Field) { // Each field declaration may actually be multiple declarations of the same // type. For example, consider: @@ -88,12 +88,24 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { // x, y, z int // } // - // We invoke the call-backs once per such instance. Embedded fields are not - // allowed, and results in a panic. + // We invoke the call-backs once per such instance. + + // Handle embedded fields. Embedded fields have no names, but can be + // referenced by the type name. if len(f.Names) < 1 { - panic("Precondition not met: attempted to dispatch on embedded field") + switch v := f.Type.(type) { + case *ast.Ident: + fd.primitive(v, v) + case *ast.SelectorExpr: + fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel) + default: + // Note: Arrays can't be embedded, which is handled here. + panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type)) + } + return } + // Non-embedded field. for _, name := range f.Names { switch v := f.Type.(type) { case *ast.Ident: diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go index f74be5c29..6e4a3e8c4 100644 --- a/tools/go_marshal/main.go +++ b/tools/go_marshal/main.go @@ -31,10 +31,11 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") - output = flag.String("output", "", "output file") - outputTest = flag.String("output_test", "", "output file for tests") - imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + outputTestUnconditional = flag.String("output_test_unconditional", "", "output file for unconditional tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") ) func main() { @@ -61,7 +62,7 @@ func main() { // as an import. extraImports = strings.Split(*imports, ",") } - g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *outputTestUnconditional, *pkg, extraImports) if err != nil { panic(err) } diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go deleted file mode 100644 index ebcf130ae..000000000 --- a/tools/go_marshal/primitive/primitive.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package primitive defines marshal.Marshallable implementations for primitive -// types. -package primitive - -import ( - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" -) - -// Int16 is a marshal.Marshallable implementation for int16. -// -// +marshal slice:Int16Slice:inner -type Int16 int16 - -// Uint16 is a marshal.Marshallable implementation for uint16. -// -// +marshal slice:Uint16Slice:inner -type Uint16 uint16 - -// Int32 is a marshal.Marshallable implementation for int32. -// -// +marshal slice:Int32Slice:inner -type Int32 int32 - -// Uint32 is a marshal.Marshallable implementation for uint32. -// -// +marshal slice:Uint32Slice:inner -type Uint32 uint32 - -// Int64 is a marshal.Marshallable implementation for int64. -// -// +marshal slice:Int64Slice:inner -type Int64 int64 - -// Uint64 is a marshal.Marshallable implementation for uint64. -// -// +marshal slice:Uint64Slice:inner -type Uint64 uint64 - -// Below, we define some convenience functions for marshalling primitive types -// using the newtypes above, without requiring superfluous casts. - -// 16-bit integers - -// CopyInt16In is a convenient wrapper for copying in an int16 from the task's -// memory. -func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) { - var buf Int16 - n, err := buf.CopyIn(task, addr) - if err != nil { - return n, err - } - *dst = int16(buf) - return n, nil -} - -// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's -// memory. -func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) { - srcP := Int16(src) - return srcP.CopyOut(task, addr) -} - -// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's -// memory. -func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) { - var buf Uint16 - n, err := buf.CopyIn(task, addr) - if err != nil { - return n, err - } - *dst = uint16(buf) - return n, nil -} - -// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's -// memory. -func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) { - srcP := Uint16(src) - return srcP.CopyOut(task, addr) -} - -// 32-bit integers - -// CopyInt32In is a convenient wrapper for copying in an int32 from the task's -// memory. -func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) { - var buf Int32 - n, err := buf.CopyIn(task, addr) - if err != nil { - return n, err - } - *dst = int32(buf) - return n, nil -} - -// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's -// memory. -func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) { - srcP := Int32(src) - return srcP.CopyOut(task, addr) -} - -// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's -// memory. -func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) { - var buf Uint32 - n, err := buf.CopyIn(task, addr) - if err != nil { - return n, err - } - *dst = uint32(buf) - return n, nil -} - -// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's -// memory. -func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) { - srcP := Uint32(src) - return srcP.CopyOut(task, addr) -} - -// 64-bit integers - -// CopyInt64In is a convenient wrapper for copying in an int64 from the task's -// memory. -func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) { - var buf Int64 - n, err := buf.CopyIn(task, addr) - if err != nil { - return n, err - } - *dst = int64(buf) - return n, nil -} - -// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's -// memory. -func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) { - srcP := Int64(src) - return srcP.CopyOut(task, addr) -} - -// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's -// memory. -func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) { - var buf Uint64 - n, err := buf.CopyIn(task, addr) - if err != nil { - return n, err - } - *dst = uint64(buf) - return n, nil -} - -// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's -// memory. -func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) { - srcP := Uint64(src) - return srcP.CopyOut(task, addr) -} diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index 2fbcc8a03..4b27773c2 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -35,10 +35,10 @@ go_test( srcs = ["marshal_test.go"], deps = [ ":test", + "//pkg/marshal", "//pkg/syserror", "//pkg/usermem", "//tools/go_marshal/analysis", - "//tools/go_marshal/marshal", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD index f74e6ffae..2981ef196 100644 --- a/tools/go_marshal/test/escape/BUILD +++ b/tools/go_marshal/test/escape/BUILD @@ -7,8 +7,8 @@ go_library( testonly = 1, srcs = ["escape.go"], deps = [ + "//pkg/marshal", "//pkg/usermem", - "//tools/go_marshal/marshal", "//tools/go_marshal/test", ], ) diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go index 6a46ddbf8..7f62b0a2b 100644 --- a/tools/go_marshal/test/escape/escape.go +++ b/tools/go_marshal/test/escape/escape.go @@ -15,34 +15,34 @@ package escape import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" "gvisor.dev/gvisor/tools/go_marshal/test" ) -// dummyTask implements marshal.Task. -type dummyTask struct { +// dummyCopyContext implements marshal.CopyContext. +type dummyCopyContext struct { } -func (*dummyTask) CopyScratchBuffer(size int) []byte { +func (*dummyCopyContext) CopyScratchBuffer(size int) []byte { return make([]byte, size) } -func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { +func (*dummyCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { return len(b), nil } -func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { +func (*dummyCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { return len(b), nil } -func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { +func (t *dummyCopyContext) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { buf := t.CopyScratchBuffer(marshallable.SizeBytes()) marshallable.MarshalBytes(buf) t.CopyOutBytes(addr, buf) } -func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { +func (t *dummyCopyContext) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { buf := t.CopyScratchBuffer(marshallable.SizeBytes()) marshallable.MarshalUnsafe(buf) t.CopyOutBytes(addr, buf) @@ -50,21 +50,22 @@ func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marsha // +checkescape:all //go:nosplit -func doCopyIn(t *dummyTask) { +func doCopyIn(t *dummyCopyContext) { var stat test.Stat stat.CopyIn(t, usermem.Addr(0xf000ba12)) } // +checkescape:all //go:nosplit -func doCopyOut(t *dummyTask) { +func doCopyOut(t *dummyCopyContext) { var stat test.Stat stat.CopyOut(t, usermem.Addr(0xf000ba12)) } // +mustescape:builtin // +mustescape:stack -func doMarshalBytesDirect(t *dummyTask) { +//go:nosplit +func doMarshalBytesDirect(t *dummyCopyContext) { var stat test.Stat buf := t.CopyScratchBuffer(stat.SizeBytes()) stat.MarshalBytes(buf) @@ -73,7 +74,8 @@ func doMarshalBytesDirect(t *dummyTask) { // +mustescape:builtin // +mustescape:stack -func doMarshalUnsafeDirect(t *dummyTask) { +//go:nosplit +func doMarshalUnsafeDirect(t *dummyCopyContext) { var stat test.Stat buf := t.CopyScratchBuffer(stat.SizeBytes()) stat.MarshalUnsafe(buf) @@ -82,14 +84,16 @@ func doMarshalUnsafeDirect(t *dummyTask) { // +mustescape:local,heap // +mustescape:stack -func doMarshalBytesViaMarshallable(t *dummyTask) { +//go:nosplit +func doMarshalBytesViaMarshallable(t *dummyCopyContext) { var stat test.Stat t.MarshalBytes(usermem.Addr(0xf000ba12), &stat) } // +mustescape:local,heap // +mustescape:stack -func doMarshalUnsafeViaMarshallable(t *dummyTask) { +//go:nosplit +func doMarshalUnsafeViaMarshallable(t *dummyCopyContext) { var stat test.Stat t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) } diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go index 16829ee45..a00f9a684 100644 --- a/tools/go_marshal/test/marshal_test.go +++ b/tools/go_marshal/test/marshal_test.go @@ -27,22 +27,22 @@ import ( "unsafe" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" - "gvisor.dev/gvisor/tools/go_marshal/marshal" "gvisor.dev/gvisor/tools/go_marshal/test" ) var simulatedErr error = syserror.EFAULT -// mockTask implements marshal.Task. -type mockTask struct { +// mockCopyContext implements marshal.CopyContext. +type mockCopyContext struct { taskMem usermem.BytesIO } // populate fills the task memory with the contents of val. -func (t *mockTask) populate(val interface{}) { +func (t *mockCopyContext) populate(val interface{}) { var buf bytes.Buffer // Use binary.Write so we aren't testing go-marshal against its own // potentially buggy implementation. @@ -52,7 +52,7 @@ func (t *mockTask) populate(val interface{}) { t.taskMem.Bytes = buf.Bytes() } -func (t *mockTask) setLimit(n int) { +func (t *mockCopyContext) setLimit(n int) { if len(t.taskMem.Bytes) < n { grown := make([]byte, n) copy(grown, t.taskMem.Bytes) @@ -62,22 +62,22 @@ func (t *mockTask) setLimit(n int) { t.taskMem.Bytes = t.taskMem.Bytes[:n] } -// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer. -func (t *mockTask) CopyScratchBuffer(size int) []byte { +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (t *mockCopyContext) CopyScratchBuffer(size int) []byte { return make([]byte, size) } -// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. The implementation // completely ignores the target address and stores a copy of b in its // internally buffer, overriding any previous contents. -func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { +func (t *mockCopyContext) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{}) } -// CopyInBytes implements marshal.Task.CopyInBytes. The implementation +// CopyInBytes implements marshal.CopyContext.CopyInBytes. The implementation // completely ignores the source address and always fills b from the begining of // its internal buffer. -func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { +func (t *mockCopyContext) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{}) } @@ -171,11 +171,11 @@ func compareMemory(t *testing.T, expected, actual []byte, n int) { // dst. The task signals an error at limit bytes during copy-in, which should // result in a truncated unmarshalling. func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { - var task mockTask - task.populate(src) - task.setLimit(limit) + var cc mockCopyContext + cc.populate(src) + cc.setLimit(limit) - n, err := dst.CopyIn(&task, usermem.Addr(0)) + n, err := dst.CopyIn(&cc, usermem.Addr(0)) if n != limit { t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -202,10 +202,10 @@ func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { // limitedCopyOut marshals src to task memory. The task signals an error at // limit bytes during copy-out, which should result in a truncated marshalling. func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { - var task mockTask - task.setLimit(limit) + var cc mockCopyContext + cc.setLimit(limit) - n, err := src.CopyOut(&task, usermem.Addr(0)) + n, err := src.CopyOut(&cc, usermem.Addr(0)) if n != limit { t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -215,7 +215,7 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { expectedMem := unsafeMemory(src) defer runtime.KeepAlive(src) - actualMem := task.taskMem.Bytes + actualMem := cc.taskMem.Bytes compareMemory(t, expectedMem, actualMem, n) } @@ -223,10 +223,10 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { // copyOutN marshals src to task memory, requesting the marshalling to be // limited to limit bytes. func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { - var task mockTask - task.setLimit(limit) + var cc mockCopyContext + cc.setLimit(limit) - n, err := src.CopyOutN(&task, usermem.Addr(0), limit) + n, err := src.CopyOutN(&cc, usermem.Addr(0), limit) if err != nil { t.Errorf("CopyOut returned unexpected error: %v", err) } @@ -236,7 +236,7 @@ func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { expectedMem := unsafeMemory(src) defer runtime.KeepAlive(src) - actualMem := task.taskMem.Bytes + actualMem := cc.taskMem.Bytes t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:]) t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:]) @@ -303,20 +303,20 @@ func TestLimitedMarshalling(t *testing.T) { func TestLimitedSliceMarshalling(t *testing.T) { types := []struct { arrayPtrType reflect.Type - copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error) - copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error) + copySliceIn func(cc marshal.CopyContext, addr usermem.Addr, dstSlice interface{}) (int, error) + copySliceOut func(cc marshal.CopyContext, addr usermem.Addr, srcSlice interface{}) (int, error) unsafeMemory func(arrPtr interface{}) []byte }{ // Packed types. { reflect.TypeOf((*[20]test.Stat)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[20]test.Stat)[:] - return test.CopyStatSliceIn(task, addr, slice) + return test.CopyStatSliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[20]test.Stat)[:] - return test.CopyStatSliceOut(task, addr, slice) + return test.CopyStatSliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[20]test.Stat)[:] @@ -325,13 +325,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[1]test.Stat)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[1]test.Stat)[:] - return test.CopyStatSliceIn(task, addr, slice) + return test.CopyStatSliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[1]test.Stat)[:] - return test.CopyStatSliceOut(task, addr, slice) + return test.CopyStatSliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[1]test.Stat)[:] @@ -340,13 +340,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[5]test.SignalSetAlias)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[5]test.SignalSetAlias)[:] - return test.CopySignalSetAliasSliceIn(task, addr, slice) + return test.CopySignalSetAliasSliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[5]test.SignalSetAlias)[:] - return test.CopySignalSetAliasSliceOut(task, addr, slice) + return test.CopySignalSetAliasSliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[5]test.SignalSetAlias)[:] @@ -356,13 +356,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { // Non-packed types. { reflect.TypeOf((*[20]test.Type1)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[20]test.Type1)[:] - return test.CopyType1SliceIn(task, addr, slice) + return test.CopyType1SliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[20]test.Type1)[:] - return test.CopyType1SliceOut(task, addr, slice) + return test.CopyType1SliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[20]test.Type1)[:] @@ -371,13 +371,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[1]test.Type1)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[1]test.Type1)[:] - return test.CopyType1SliceIn(task, addr, slice) + return test.CopyType1SliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[1]test.Type1)[:] - return test.CopyType1SliceOut(task, addr, slice) + return test.CopyType1SliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[1]test.Type1)[:] @@ -386,13 +386,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[7]test.Type8)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[7]test.Type8)[:] - return test.CopyType8SliceIn(task, addr, slice) + return test.CopyType8SliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[7]test.Type8)[:] - return test.CopyType8SliceOut(task, addr, slice) + return test.CopyType8SliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[7]test.Type8)[:] @@ -439,11 +439,11 @@ func TestLimitedSliceMarshalling(t *testing.T) { limit += elem.SizeBytes() / 2 analysis.RandomizeValue(expected) - var task mockTask - task.populate(expected) - task.setLimit(limit) + var cc mockCopyContext + cc.populate(expected) + cc.setLimit(limit) - n, err := tt.copySliceIn(&task, usermem.Addr(0), actual) + n, err := tt.copySliceIn(&cc, usermem.Addr(0), actual) if n != limit { t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -493,11 +493,11 @@ func TestLimitedSliceMarshalling(t *testing.T) { limit += elem.SizeBytes() / 2 analysis.RandomizeValue(expected) - var task mockTask - task.populate(expected) - task.setLimit(limit) + var cc mockCopyContext + cc.populate(expected) + cc.setLimit(limit) - n, err := tt.copySliceOut(&task, usermem.Addr(0), expected) + n, err := tt.copySliceOut(&cc, usermem.Addr(0), expected) if n != limit { t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -507,7 +507,7 @@ func TestLimitedSliceMarshalling(t *testing.T) { expectedMem := tt.unsafeMemory(expected) defer runtime.KeepAlive(expected) - actualMem := task.taskMem.Bytes + actualMem := cc.taskMem.Bytes compareMemory(t, expectedMem, actualMem, n) }) diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index f75ca1b7f..d9e9f341b 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -174,3 +174,27 @@ type Type9 struct { x int64 y [sizeA]int32 } + +// Type10Embed is a test data type which is be embedded into another type. +// +// +marshal +type Type10Embed struct { + x int64 +} + +// Type10 is a test data type which contains an embedded struct. +// +// +marshal +type Type10 struct { + Type10Embed + y int64 +} + +// Type11 is a test data type which contains an embedded struct from an external +// package. +// +// +marshal +type Type11 struct { + ex.External + y int64 +} diff --git a/tools/go_mod.sh b/tools/go_mod.sh deleted file mode 100755 index 84b779d6d..000000000 --- a/tools/go_mod.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -eo pipefail - -# Build the :gopath target. -bazel build //:gopath -declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" - -# Copy go.mod and execute the command. -cp -a go.mod go.sum "${gopathdir}" -(cd "${gopathdir}" && go mod "$@") -cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . - -# Cleanup the WORKSPACE file. -bazel run //:gazelle -- update-repos -from_file=go.mod diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index 503cdf2e5..913558b4e 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary") +load("//tools:defs.bzl", "bzl_library", "go_binary") package(licenses = ["notice"]) @@ -8,3 +8,9 @@ go_binary( visibility = ["//:sandbox"], deps = ["//tools/tags"], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/installers/BUILD b/tools/installers/BUILD index caa7b1983..13d3cc5e0 100644 --- a/tools/installers/BUILD +++ b/tools/installers/BUILD @@ -5,15 +5,12 @@ package( licenses = ["notice"], ) -filegroup( - name = "runsc", - srcs = ["//runsc"], -) - sh_binary( name = "head", srcs = ["head.sh"], - data = [":runsc"], + data = [ + "//runsc", + ], ) sh_binary( @@ -30,6 +27,15 @@ sh_binary( ) sh_binary( + name = "containerd", + srcs = ["containerd.sh"], +) + +sh_binary( name = "shim", srcs = ["shim.sh"], + data = [ + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", + ], ) diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh new file mode 100755 index 000000000..6b7bb261c --- /dev/null +++ b/tools/installers/containerd.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +declare -r CONTAINERD_VERSION=${CONTAINERD_VERSION:-1.3.0} +declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')" +declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $2; }')" + +# Default to an older version for crictl for containerd <= 1.2. +if [[ "${CONTAINERD_MAJOR}" -eq 1 ]] && [[ "${CONTAINERD_MINOR}" -le 2 ]]; then + declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.13.0} +else + declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.18.0} +fi + +# Helper for Go packages below. +install_helper() { + PACKAGE="${1}" + TAG="${2}" + + # Clone the repository. + mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ + git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" + + # Checkout and build the repository. + (cd "${GOPATH}"/src/"${PACKAGE}" && \ + git checkout "${TAG}" && \ + make && \ + make install) +} + +# Install dependencies for the crictl tests. +while true; do + if (apt-get update && apt-get install -y \ + btrfs-tools \ + libseccomp-dev); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Install containerd & cri-tools. +declare -rx GOPATH=$(mktemp -d --tmpdir gopathXXXXX) +install_helper github.com/containerd/containerd "v${CONTAINERD_VERSION}" "${GOPATH}" +install_helper github.com/kubernetes-sigs/cri-tools "v${CRITOOLS_VERSION}" "${GOPATH}" + +# Configure containerd-shim. +# +# Note that for versions <= 1.1 the legacy shim must be installed in /usr/bin, +# which should align with the installer script in head.sh (or master.sh). +if [[ "${CONTAINERD_MAJOR}" -le 1 ]] && [[ "${CONTAINERD_MINOR}" -lt 2 ]]; then + declare -r shim_config_path=/etc/containerd/gvisor-containerd-shim.toml + mkdir -p $(dirname ${shim_config_path}) + cat > ${shim_config_path} <<-EOF + runc_shim = "/usr/bin/containerd-shim" + +[runsc_config] + debug = "true" + debug-log = "/tmp/runsc-logs/" + strace = "true" + file-access = "shared" +EOF +fi + +# Configure CNI. +(cd "${GOPATH}" && src/github.com/containerd/containerd/script/setup/install-cni) +cat <<EOF | sudo tee /etc/cni/net.d/10-bridge.conf +{ + "cniVersion": "0.3.1", + "name": "bridge", + "type": "bridge", + "bridge": "cnio0", + "isGateway": true, + "ipMasq": true, + "ipam": { + "type": "host-local", + "ranges": [ + [{"subnet": "10.200.0.0/24"}] + ], + "routes": [{"dst": "0.0.0.0/0"}] + } +} +EOF +cat <<EOF | sudo tee /etc/cni/net.d/99-loopback.conf +{ + "cniVersion": "0.3.1", + "type": "loopback" +} +EOF + +# Configure crictl. +cat <<EOF | sudo tee /etc/crictl.yaml +runtime-endpoint: unix:///run/containerd/containerd.sock +EOF + +# Cleanup. +rm -rf "${GOPATH}" diff --git a/tools/installers/head.sh b/tools/installers/head.sh index 7fc566ebd..a613fcb5b 100755 --- a/tools/installers/head.sh +++ b/tools/installers/head.sh @@ -15,7 +15,13 @@ # limitations under the License. # Install our runtime. -$(find . -executable -type f -name runsc) install +runfiles=. +if [[ -d "$0.runfiles" ]]; then + runfiles="$0.runfiles" +fi +$(find -L "${runfiles}" -executable -type f -name runsc) install # Restart docker. -service docker restart || true +if service docker status 2>/dev/null; then + service docker restart +fi diff --git a/tools/installers/shim.sh b/tools/installers/shim.sh index f7dd790a1..8153ce283 100755 --- a/tools/installers/shim.sh +++ b/tools/installers/shim.sh @@ -14,11 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Reinstall the latest containerd shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -mv ${shim_path} /usr/local/bin/gvisor-containerd-shim +# Install all the shims. +# +# Note that containerd looks at the current executable directory +# in order to find the shim binary. So we need to check in order +# of preference. The local containerd installer will install to +# /usr/local, so we use that first. +if [[ -x /usr/local/bin/containerd ]]; then + containerd_install_dir=/usr/local/bin +else + containerd_install_dir=/usr/bin +fi +runfiles=. +if [[ -d "$0.runfiles" ]]; then + runfiles="$0.runfiles" +fi +find -L "${runfiles}" -executable -type f -name containerd-shim-runsc-v1 -exec cp -L {} "${containerd_install_dir}" \; +find -L "${runfiles}" -executable -type f -name gvisor-containerd-shim -exec cp -L {} "${containerd_install_dir}" \; diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD deleted file mode 100644 index 4ef1a3124..000000000 --- a/tools/issue_reviver/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "issue_reviver", - srcs = ["main.go"], - deps = [ - "//tools/issue_reviver/github", - "//tools/issue_reviver/reviver", - ], -) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD deleted file mode 100644 index da4133472..000000000 --- a/tools/issue_reviver/github/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "github", - srcs = ["github.go"], - visibility = [ - "//tools/issue_reviver:__subpackages__", - ], - deps = [ - "//tools/issue_reviver/reviver", - "@com_github_google_go-github//github:go_default_library", - "@org_golang_x_oauth2//:go_default_library", - ], -) diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go deleted file mode 100644 index 47c796b8a..000000000 --- a/tools/issue_reviver/main.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package main is the entry point for issue_reviver. -package main - -import ( - "flag" - "fmt" - "io/ioutil" - "os" - "strings" - - "gvisor.dev/gvisor/tools/issue_reviver/github" - "gvisor.dev/gvisor/tools/issue_reviver/reviver" -) - -var ( - owner string - repo string - tokenFile string - path string - dryRun bool -) - -// Keep the options simple for now. Supports only a single path and repo. -func init() { - flag.StringVar(&owner, "owner", "", "Github project org/owner to look for issues") - flag.StringVar(&repo, "repo", "", "Github repo to look for issues") - flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github") - flag.StringVar(&path, "path", ".", "Path to scan for TODOs") - flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues") -} - -func main() { - // Set defaults from the environment. - repository := os.Getenv("GITHUB_REPOSITORY") - if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 { - owner = parts[0] - repo = parts[1] - } - - // Parse flags. - flag.Parse() - - // Check for mandatory parameters. - if len(owner) == 0 { - fmt.Println("missing --owner option.") - flag.Usage() - os.Exit(1) - } - if len(repo) == 0 { - fmt.Println("missing --repo option.") - flag.Usage() - os.Exit(1) - } - if len(path) == 0 { - fmt.Println("missing --path option.") - flag.Usage() - os.Exit(1) - } - - // The access token may be passed as a file so it doesn't show up in - // command line arguments. It also may be provided through the - // environment to faciliate use through GitHub's CI system. - token := os.Getenv("GITHUB_TOKEN") - if len(tokenFile) != 0 { - bytes, err := ioutil.ReadFile(tokenFile) - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - token = string(bytes) - } - - bugger, err := github.NewBugger(token, owner, repo, dryRun) - if err != nil { - fmt.Fprintln(os.Stderr, "Error getting github issues:", err) - os.Exit(1) - } - rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) - if errs := rev.Run(); len(errs) > 0 { - fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) - for _, err := range errs { - fmt.Fprintf(os.Stderr, "\t%v\n", err) - } - os.Exit(1) - } -} diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD deleted file mode 100644 index d262932bd..000000000 --- a/tools/issue_reviver/reviver/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "reviver", - srcs = ["reviver.go"], - visibility = [ - "//tools/issue_reviver:__subpackages__", - ], -) - -go_test( - name = "reviver_test", - size = "small", - srcs = ["reviver_test.go"], - library = ":reviver", -) diff --git a/tools/make_apt.sh b/tools/make_apt.sh index 3fb1066e5..13c5edd76 100755 --- a/tools/make_apt.sh +++ b/tools/make_apt.sh @@ -54,18 +54,22 @@ declare -r release="${root}/dists/${suite}" mkdir -p "${release}" # Create a temporary keyring, and ensure it is cleaned up. +# Using separate homedir allows us to install apt repositories multiple times +# using the same key. This is a limitation in GnuPG pre-2.1. declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg) +declare -r homedir=$(mktemp -d /tmp/homedirXXXXXX) +declare -r gpg_opts=("--no-default-keyring" "--secret-keyring" "${keyring}" "--homedir" "${homedir}") cleanup() { - rm -f "${keyring}" + rm -rf "${keyring}" "${homedir}" } trap cleanup EXIT # We attempt the import twice because the first one will fail if the public key # is not found. This isn't actually a failure for us, because we don't require -# the public (this may be stored separately). The second import will succeed +# the public key (this may be stored separately). The second import will succeed # because, in reality, the first import succeeded and it's a no-op. -gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" || \ - gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" +gpg "${gpg_opts[@]}" --import "${private_key}" || \ + gpg "${gpg_opts[@]}" --import "${private_key}" # Copy the packages into the root. for pkg in "$@"; do @@ -100,7 +104,8 @@ for pkg in "$@"; do cp -a "${pkg}" "${target}" chmod 0644 "${target}" if [[ "${ext}" == "deb" ]]; then - dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${target}" + # We use [*] here to expand the gpg_opts array into a single shell-word. + dpkg-sig -g "${gpg_opts[*]}" --sign builder "${target}" fi done @@ -135,5 +140,5 @@ rm "${release}"/apt.conf # Sign the release. declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512") (cd "${release}" && rm -f Release.gpg InRelease) -(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release) -(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release) +(cd "${release}" && gpg "${gpg_opts[@]}" --clearsign "${digest_opts[@]}" -o InRelease Release) +(cd "${release}" && gpg "${gpg_opts[@]}" -abs "${digest_opts[@]}" -o Release.gpg Release) diff --git a/tools/make_release.sh b/tools/make_release.sh index b1cdd47b0..9137dd9bb 100755 --- a/tools/make_release.sh +++ b/tools/make_release.sh @@ -43,8 +43,7 @@ install_raw() { # Copy the raw file & generate a sha512sum. name=$(basename "${binary}") cp -f "${binary}" "${root}/$1" - sha512sum "${root}/$1/${name}" | \ - awk "{print $$1 \" ${name}\"}" > "${root}/$1/${name}.sha512" + (cd "${root}/$1" && sha512sum "${name}" > "${name}.sha512") done } diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index c21b09511..9f1fcd9c7 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -1,7 +1,18 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "bzl_library", "go_library") +load("//tools/nogo:defs.bzl", "nogo_dump_tool", "nogo_stdlib") package(licenses = ["notice"]) +nogo_dump_tool( + name = "dump_tool", + visibility = ["//visibility:public"], +) + +nogo_stdlib( + name = "stdlib", + visibility = ["//visibility:public"], +) + go_library( name = "nogo", srcs = [ @@ -16,7 +27,6 @@ go_library( deps = [ "//tools/checkescape", "//tools/checkunsafe", - "//tools/nogo/data", "@org_golang_x_tools//go/analysis:go_tool_library", "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library", "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library", @@ -47,3 +57,9 @@ go_library( "@org_golang_x_tools//go/gcexportdata:go_tool_library", ], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 1c0d08661..39c2ae418 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -26,11 +26,18 @@ var ( // and should not have any special prefix applied. internalPrefix = fmt.Sprintf("^") + // internalDefault is applied when no paths are provided. + internalDefault = fmt.Sprintf("%s/.*", notPath("external")) + // externalPrefix is external workspace packages. externalPrefix = "^external/" ) // findStdPkg needs to find the bundled standard library packages. -func findStdPkg(path, GOOS, GOARCH string) (io.ReadCloser, error) { +func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) { + if path == "C" { + // Cgo builds cannot be analyzed. Skip. + return nil, ErrSkip + } return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path)) } diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD index e2d76cd5c..21ba2c306 100644 --- a/tools/nogo/check/BUILD +++ b/tools/nogo/check/BUILD @@ -7,6 +7,7 @@ package(licenses = ["notice"]) go_binary( name = "check", srcs = ["main.go"], + nogo = False, visibility = ["//visibility:public"], deps = ["//tools/nogo"], ) diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 6958fca69..cfe7b4aa4 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -84,6 +84,14 @@ var analyzerConfig = map[*analysis.Analyzer]matcher{ externalExcluded( ".*protobuf/.*.go", // Bad conversions. ".*flate/huffman_bit_writer.go", // Bad conversion. + + // Runtime internal violations. + ".*reflect/value.go", + ".*encoding/xml/xml.go", + ".*runtime/pprof/internal/profile/proto.go", + ".*fmt/scan.go", + ".*go/types/conversions.go", + ".*golang.org/x/net/dns/dnsmessage/message.go", ), ), shadow.Analyzer: disableMatches(), // Disabled for now. @@ -114,3 +122,8 @@ var analyzerConfig = map[*analysis.Analyzer]matcher{ checkescape.Analyzer: internalMatches(), checkunsafe.Analyzer: internalMatches(), } + +var escapesConfig = map[*analysis.Analyzer]matcher{ + // Informational only: include all packages. + checkescape.EscapeAnalyzer: alwaysMatches(), +} diff --git a/tools/nogo/data/BUILD b/tools/nogo/data/BUILD deleted file mode 100644 index b7564cc44..000000000 --- a/tools/nogo/data/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "data", - srcs = ["data.go"], - nogo = False, - visibility = ["//tools:__subpackages__"], -) diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 6560b57c8..480438047 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -1,6 +1,107 @@ """Nogo rules.""" -load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule") +load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") + +def _nogo_dump_tool_impl(ctx): + # Extract the Go context. + go_ctx = go_context(ctx) + + # Construct the magic dump command. + # + # Note that in some cases, the input is being fed into the tool via stdin. + # Unfortunately, the Go objdump tool expects to see a seekable file [1], so + # we need the tool to handle this case by creating a temporary file. + # + # [1] https://github.com/golang/go/issues/41051 + env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) + dumper = ctx.actions.declare_file(ctx.label.name) + ctx.actions.write(dumper, "\n".join([ + "#!/bin/bash", + "set -euo pipefail", + "if [[ $# -eq 0 ]]; then", + " T=$(mktemp -u -t libXXXXXX.a)", + " cat /dev/stdin > ${T}", + "else", + " T=$1;", + "fi", + "%s %s tool objdump ${T}" % ( + env_prefix, + go_ctx.go.path, + ), + "if [[ $# -eq 0 ]]; then", + " rm -rf ${T}", + "fi", + "", + ]), is_executable = True) + + # Include the full runfiles. + return [DefaultInfo( + runfiles = ctx.runfiles(files = go_ctx.runfiles.to_list()), + executable = dumper, + )] + +nogo_dump_tool = go_rule( + rule, + implementation = _nogo_dump_tool_impl, +) + +# NogoStdlibInfo is the set of standard library facts. +NogoStdlibInfo = provider( + "information for nogo analysis (standard library facts)", + fields = { + "facts": "serialized standard library facts", + "findings": "package findings (if relevant)", + }, +) + +def _nogo_stdlib_impl(ctx): + # Extract the Go context. + go_ctx = go_context(ctx) + + # Build the standard library facts. + facts = ctx.actions.declare_file(ctx.label.name + ".facts") + findings = ctx.actions.declare_file(ctx.label.name + ".findings") + config = struct( + Srcs = [f.path for f in go_ctx.stdlib_srcs], + GOOS = go_ctx.goos, + GOARCH = go_ctx.goarch, + Tags = go_ctx.tags, + ) + config_file = ctx.actions.declare_file(ctx.label.name + ".cfg") + ctx.actions.write(config_file, config.to_json()) + ctx.actions.run( + inputs = [config_file] + go_ctx.stdlib_srcs, + outputs = [facts, findings], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._dump_tool), + executable = ctx.files._nogo[0], + mnemonic = "GoStandardLibraryAnalysis", + progress_message = "Analyzing Go Standard Library", + arguments = go_ctx.nogo_args + [ + "-dump_tool=%s" % ctx.files._dump_tool[0].path, + "-stdlib=%s" % config_file.path, + "-findings=%s" % findings.path, + "-facts=%s" % facts.path, + ], + ) + + # Return the stdlib facts as output. + return [NogoStdlibInfo( + facts = facts, + findings = findings, + )] + +nogo_stdlib = go_rule( + rule, + implementation = _nogo_stdlib_impl, + attrs = { + "_nogo": attr.label( + default = "//tools/nogo/check:check", + ), + "_dump_tool": attr.label( + default = "//tools/nogo:dump_tool", + ), + }, +) # NogoInfo is the serialized set of package facts for a nogo analysis. # @@ -8,10 +109,14 @@ load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule") # with the source files as input. Note however, that the individual nogo rules # are simply stubs that enter into the shadow dependency tree (the "aspect"). NogoInfo = provider( + "information for nogo analysis", fields = { "facts": "serialized package facts", + "findings": "package findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", + "srcs": "original source files (for go_test support)", + "deps": "original deps (for go_test support)", }, ) @@ -21,15 +126,29 @@ def _nogo_aspect_impl(target, ctx): # All work is done in the shadow properties for go rules. For a proto # library, we simply skip the analysis portion but still need to return a # valid NogoInfo to reference the generated binary. - if ctx.rule.kind == "go_library": + if ctx.rule.kind in ("go_library", "go_binary", "go_test", "go_tool_library"): srcs = ctx.rule.files.srcs - elif ctx.rule.kind == "go_proto_library" or ctx.rule.kind == "go_wrap_cc": + deps = ctx.rule.attr.deps + elif ctx.rule.kind in ("go_proto_library", "go_wrap_cc"): srcs = [] + deps = ctx.rule.attr.deps else: return [NogoInfo()] - # Construct the Go environment from the go_context.env dictionary. - env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_context(ctx).env.items()]) + # Extract the Go context. + go_ctx = go_context(ctx) + + # If we're using the "library" attribute, then we need to aggregate the + # original library sources and dependencies into this target to perform + # proper type analysis. + if ctx.rule.kind == "go_test": + library = go_test_library(ctx.rule) + if library != None: + info = library[NogoInfo] + if hasattr(info, "srcs"): + srcs = srcs + info.srcs + if hasattr(info, "deps"): + deps = deps + info.deps # Start with all target files and srcs as input. inputs = target.files.to_list() + srcs @@ -39,48 +158,30 @@ def _nogo_aspect_impl(target, ctx): # to cleanly allow us redirect stdout to the actual output file. Perhaps # I'm missing something here, but the intermediate script does work. binaries = target.files.to_list() - disasm_file = ctx.actions.declare_file(target.label.name + ".out") - dumper = ctx.actions.declare_file("%s-dumper" % ctx.label.name) - ctx.actions.write(dumper, "\n".join([ - "#!/bin/bash", - "%s %s tool objdump %s > %s\n" % ( - env_prefix, - go_context(ctx).go.path, - [f.path for f in binaries if f.path.endswith(".a")][0], - disasm_file.path, - ), - ]), is_executable = True) - ctx.actions.run( - inputs = binaries, - outputs = [disasm_file], - tools = go_context(ctx).runfiles, - mnemonic = "GoObjdump", - progress_message = "Objdump %s" % target.label, - executable = dumper, - ) - inputs.append(disasm_file) + objfiles = [f for f in binaries if f.path.endswith(".a")] + if len(objfiles) > 0: + # Prefer the .a files for go_library targets. + target_objfile = objfiles[0] + else: + # Use the raw binary for go_binary and go_test targets. + target_objfile = binaries[0] + inputs.append(target_objfile) # Extract the importpath for this package. - importpath = go_importpath(target) - - # The nogo tool requires a configfile serialized in JSON format to do its - # work. This must line up with the nogo.Config fields. - facts = ctx.actions.declare_file(target.label.name + ".facts") - config = struct( - ImportPath = importpath, - GoFiles = [src.path for src in srcs if src.path.endswith(".go")], - NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], - GOOS = go_context(ctx).goos, - GOARCH = go_context(ctx).goarch, - Tags = go_context(ctx).tags, - FactMap = {}, # Constructed below. - ImportMap = {}, # Constructed below. - FactOutput = facts.path, - Objdump = disasm_file.path, - ) + if ctx.rule.kind == "go_test": + # If this is a test, then it will not be imported by anything else. + # We can safely set the importapth to just "test". Note that this + # is necessary if the library also imports the core library (in + # addition to including the sources directly), which happens in + # some complex cases (seccomp_victim). + importpath = "test" + else: + importpath = go_importpath(target) # Collect all info from shadow dependencies. - for dep in ctx.rule.attr.deps: + fact_map = dict() + import_map = dict() + for dep in deps: # There will be no file attribute set for all transitive dependencies # that are not go_library or go_binary rules, such as a proto rules. # This is handled by the ctx.rule.kind check above. @@ -94,45 +195,83 @@ def _nogo_aspect_impl(target, ctx): x_files = [f.path for f in info.binaries if f.path.endswith(".x")] if not len(x_files): x_files = [f.path for f in info.binaries if f.path.endswith(".a")] - config.ImportMap[info.importpath] = x_files[0] - config.FactMap[info.importpath] = info.facts.path + import_map[info.importpath] = x_files[0] + fact_map[info.importpath] = info.facts.path # Ensure the above are available as inputs. inputs.append(info.facts) inputs += info.binaries - # Write the configuration and run the tool. + # Add the standard library facts. + stdlib_facts = ctx.attr._nogo_stdlib[NogoStdlibInfo].facts + inputs.append(stdlib_facts) + + # The nogo tool operates on a configuration serialized in JSON format. + facts = ctx.actions.declare_file(target.label.name + ".facts") + findings = ctx.actions.declare_file(target.label.name + ".findings") + escapes = ctx.actions.declare_file(target.label.name + ".escapes") + config = struct( + ImportPath = importpath, + GoFiles = [src.path for src in srcs if src.path.endswith(".go")], + NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], + GOOS = go_ctx.goos, + GOARCH = go_ctx.goarch, + Tags = go_ctx.tags, + FactMap = fact_map, + ImportMap = import_map, + StdlibFacts = stdlib_facts.path, + ) config_file = ctx.actions.declare_file(target.label.name + ".cfg") ctx.actions.write(config_file, config.to_json()) inputs.append(config_file) - - # Run the nogo tool itself. ctx.actions.run( inputs = inputs, - outputs = [facts], - tools = go_context(ctx).runfiles, + outputs = [facts, findings, escapes], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._dump_tool), executable = ctx.files._nogo[0], mnemonic = "GoStaticAnalysis", progress_message = "Analyzing %s" % target.label, - arguments = ["-config=%s" % config_file.path], + arguments = go_ctx.nogo_args + [ + "-binary=%s" % target_objfile.path, + "-dump_tool=%s" % ctx.files._dump_tool[0].path, + "-package=%s" % config_file.path, + "-findings=%s" % findings.path, + "-facts=%s" % facts.path, + "-escapes=%s" % escapes.path, + ], ) # Return the package facts as output. - return [NogoInfo( - facts = facts, - importpath = importpath, - binaries = binaries, - )] + return [ + NogoInfo( + facts = facts, + findings = findings, + importpath = importpath, + binaries = binaries, + srcs = srcs, + deps = deps, + ), + OutputGroupInfo( + # Expose all findings (should just be a single file). This can be + # used for build analysis of the nogo findings. + nogo_findings = depset([findings]), + # Expose all escape analysis findings (see above). + nogo_escapes = depset([escapes]), + ), + ] nogo_aspect = go_rule( aspect, implementation = _nogo_aspect_impl, - attr_aspects = ["deps"], + attr_aspects = [ + "deps", + "library", + "embed", + ], attrs = { - "_nogo": attr.label( - default = "//tools/nogo/check:check", - allow_single_file = True, - ), + "_nogo": attr.label(default = "//tools/nogo/check:check"), + "_nogo_stdlib": attr.label(default = "//tools/nogo:stdlib"), + "_dump_tool": attr.label(default = "//tools/nogo:dump_tool"), }, ) @@ -144,13 +283,26 @@ def _nogo_test_impl(ctx): # this way so that any test applied is effectively pushed down to all # upstream dependencies through the aspect. inputs = [] + findings = [] runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) runner_content = ["#!/bin/bash"] for dep in ctx.attr.deps: + # Extract the findings. info = dep[NogoInfo] - inputs.append(info.facts) + inputs.append(info.findings) + findings.append(info.findings) + + # Include all source files, transitively. This will make this target + # "directly affected" for the purpose of build analysis. + inputs += info.srcs - # Draw a sweet unicode checkmark with the package name (in green). + # If there are findings, dump them and fail. + runner_content.append("if [[ -s \"%s\" ]]; then cat \"%s\" && exit 1; fi" % ( + info.findings.short_path, + info.findings.short_path, + )) + + # Otherwise, draw a sweet unicode checkmark with the package name (in green). runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath) runner_content.append("exit 0\n") ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) @@ -167,6 +319,10 @@ _nogo_test = rule( test = True, ) -def nogo_test(**kwargs): +def nogo_test(name, **kwargs): tags = kwargs.pop("tags", []) + ["nogo"] - _nogo_test(tags = tags, **kwargs) + _nogo_test( + name = name, + tags = tags, + **kwargs + ) diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go index 57a250501..5c39be630 100644 --- a/tools/nogo/matchers.go +++ b/tools/nogo/matchers.go @@ -16,7 +16,6 @@ package nogo import ( "go/token" - "path/filepath" "regexp" "strings" @@ -44,11 +43,30 @@ type pathRegexps struct { func buildRegexps(prefix string, args ...string) []*regexp.Regexp { result := make([]*regexp.Regexp, 0, len(args)) for _, arg := range args { - result = append(result, regexp.MustCompile(filepath.Join(prefix, arg))) + result = append(result, regexp.MustCompile(prefix+arg)) } return result } +// notPath works around the lack of backtracking. +// +// It is used to construct a regular expression for non-matching components. +func notPath(name string) string { + sb := strings.Builder{} + sb.WriteString("(") + for i := range name { + if i > 0 { + sb.WriteString("|") + } + sb.WriteString(name[:i]) + sb.WriteString("[^") + sb.WriteByte(name[i]) + sb.WriteString("/][^/]*") + } + sb.WriteString(")") + return sb.String() +} + // ShouldReport implements matcher.ShouldReport. func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { fullPos := fs.Position(d.Pos).String() @@ -79,7 +97,7 @@ func externalExcluded(paths ...string) *pathRegexps { // internalMatches returns a path matcher for internal packages. func internalMatches() *pathRegexps { return &pathRegexps{ - expr: buildRegexps(internalPrefix, ".*"), + expr: buildRegexps(internalPrefix, internalDefault), include: true, } } diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index 203cdf688..120fdcff5 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -20,6 +20,7 @@ package nogo import ( "encoding/json" + "errors" "flag" "fmt" "go/ast" @@ -31,50 +32,89 @@ import ( "io/ioutil" "log" "os" + "path" "path/filepath" "reflect" + "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/internal/facts" "golang.org/x/tools/go/gcexportdata" - "gvisor.dev/gvisor/tools/nogo/data" + + // Special case: flags live here and change overall behavior. + "gvisor.dev/gvisor/tools/checkescape" ) -// pkgConfig is serialized as the configuration. +// stdlibConfig is serialized as the configuration. // -// This contains everything required for the analysis. -type pkgConfig struct { - ImportPath string - GoFiles []string - NonGoFiles []string - Tags []string - GOOS string - GOARCH string - ImportMap map[string]string - FactMap map[string]string - FactOutput string - Objdump string +// This contains everything required for stdlib analysis. +type stdlibConfig struct { + Srcs []string + GOOS string + GOARCH string + Tags []string } -// loadFacts finds and loads facts per FactMap. -func (c *pkgConfig) loadFacts(path string) ([]byte, error) { - realPath, ok := c.FactMap[path] - if !ok { - return nil, nil // No facts available. - } +// packageConfig is serialized as the configuration. +// +// This contains everything required for single package analysis. +type packageConfig struct { + ImportPath string + GoFiles []string + NonGoFiles []string + Tags []string + GOOS string + GOARCH string + ImportMap map[string]string + FactMap map[string]string + StdlibFacts string +} - // Read the files file. - data, err := ioutil.ReadFile(realPath) - if err != nil { - return nil, err +// loader is a fact-loader function. +type loader func(string) ([]byte, error) + +// saver is a fact-saver function. +type saver func([]byte) error + +// factLoader returns a function that loads facts. +// +// This resolves all standard library facts and imported package facts up +// front. The returned loader function will never return an error, only +// empty facts. +// +// This is done because all stdlib data is stored together, and we don't want +// to load this data many times over. +func (c *packageConfig) factLoader() (loader, error) { + allFacts := make(map[string][]byte) + if c.StdlibFacts != "" { + data, err := ioutil.ReadFile(c.StdlibFacts) + if err != nil { + return nil, fmt.Errorf("error loading stdlib facts from %q: %w", c.StdlibFacts, err) + } + var stdlibFacts map[string][]byte + if err := json.Unmarshal(data, &stdlibFacts); err != nil { + return nil, fmt.Errorf("error loading stdlib facts: %w", err) + } + for pkg, data := range stdlibFacts { + allFacts[pkg] = data + } + } + for pkg, file := range c.FactMap { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("error loading %q: %w", file, err) + } + allFacts[pkg] = data } - return data, nil + return func(path string) ([]byte, error) { + return allFacts[path], nil + }, nil } // shouldInclude indicates whether the file should be included. // // NOTE: This does only basic parsing of tags. -func (c *pkgConfig) shouldInclude(path string) (bool, error) { +func (c *packageConfig) shouldInclude(path string) (bool, error) { ctx := build.Default ctx.GOOS = c.GOOS ctx.GOARCH = c.GOARCH @@ -88,9 +128,11 @@ func (c *pkgConfig) shouldInclude(path string) (bool, error) { // files, and the facts. Note that this importer implementation will always // pass when a given package is not available. type importer struct { - pkgConfig - fset *token.FileSet - cache map[string]*types.Package + *packageConfig + fset *token.FileSet + cache map[string]*types.Package + lastErr error + callback func(string) error } // Import implements types.Importer.Import. @@ -101,6 +143,17 @@ func (i *importer) Import(path string) (*types.Package, error) { // analyzers are specifically looking for this. return types.Unsafe, nil } + + // Call the internal callback. This is used to resolve loading order + // for the standard library. See checkStdlib. + if i.callback != nil { + if err := i.callback(path); err != nil { + i.lastErr = err + return nil, err + } + } + + // Actually load the data. realPath, ok := i.ImportMap[path] var ( rc io.ReadCloser @@ -109,12 +162,13 @@ func (i *importer) Import(path string) (*types.Package, error) { if !ok { // Not found in the import path. Attempt to find the package // via the standard library. - rc, err = findStdPkg(path, i.GOOS, i.GOARCH) + rc, err = findStdPkg(i.GOOS, i.GOARCH, path) } else { // Open the file. rc, err = os.Open(realPath) } if err != nil { + i.lastErr = err return nil, err } defer rc.Close() @@ -128,6 +182,154 @@ func (i *importer) Import(path string) (*types.Package, error) { return gcexportdata.Read(r, i.fset, i.cache, path) } +// ErrSkip indicates the package should be skipped. +var ErrSkip = errors.New("skipped") + +// checkStdlib checks the standard library. +// +// This constructs a synthetic package configuration for each library in the +// standard library sources, and call checkPackage repeatedly. +// +// Note that not all parts of the source are expected to build. We skip obvious +// test files, and cmd files, which should not be dependencies. +func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]string, []byte, error) { + if len(config.Srcs) == 0 { + return nil, nil, nil + } + + // Ensure all paths are normalized. + for i := 0; i < len(config.Srcs); i++ { + config.Srcs[i] = path.Clean(config.Srcs[i]) + } + + // Calculate the root source directory. This is always a directory + // named 'src', of which we simply take the first we find. This is a + // bit fragile, but works for all currently known Go source + // configurations. + // + // Note that there may be extra files outside of the root source + // directory; we simply ignore those. + rootSrcPrefix := "" + for _, file := range config.Srcs { + const src = "/src/" + i := strings.Index(file, src) + if i == -1 { + // Superfluous file. + continue + } + + // Index of first character after /src/. + i += len(src) + rootSrcPrefix = file[:i] + break + } + + // Aggregate all files by directory. + packages := make(map[string]*packageConfig) + for _, file := range config.Srcs { + if !strings.HasPrefix(file, rootSrcPrefix) { + // Superflouous file. + continue + } + + d := path.Dir(file) + if len(rootSrcPrefix) >= len(d) { + continue // Not a file. + } + pkg := d[len(rootSrcPrefix):] + // Skip cmd packages and obvious test files: see above. + if strings.HasPrefix(pkg, "cmd/") || strings.HasSuffix(file, "_test.go") { + continue + } + c, ok := packages[pkg] + if !ok { + c = &packageConfig{ + ImportPath: pkg, + GOOS: config.GOOS, + GOARCH: config.GOARCH, + Tags: config.Tags, + } + packages[pkg] = c + } + // Add the files appropriately. Note that they will be further + // filtered by architecture and build tags below, so this need + // not be done immediately. + if strings.HasSuffix(file, ".go") { + c.GoFiles = append(c.GoFiles, file) + } else { + c.NonGoFiles = append(c.NonGoFiles, file) + } + } + + // Closure to check a single package. + allFindings := make([]string, 0) + stdlibFacts := make(map[string][]byte) + var checkOne func(pkg string) error // Recursive. + checkOne = func(pkg string) error { + // Is this already done? + if _, ok := stdlibFacts[pkg]; ok { + return nil + } + + // Lookup the configuration. + config, ok := packages[pkg] + if !ok { + return nil // Not known. + } + + // Find the binary package, and provide to objdump. + rc, err := findStdPkg(config.GOOS, config.GOARCH, pkg) + if err != nil { + // If there's no binary for this package, it is likely + // not built with the distribution. That's fine, we can + // just skip analysis. + return nil + } + + // Provide the input. + oldReader := checkescape.Reader + checkescape.Reader = rc // For analysis. + defer func() { + rc.Close() + checkescape.Reader = oldReader // Restore. + }() + + // Run the analysis. + findings, factData, err := checkPackage(config, ac, checkOne) + if err != nil { + // If we can't analyze a package from the standard library, + // then we skip it. It will simply not have any findings. + return nil + } + stdlibFacts[pkg] = factData + allFindings = append(allFindings, findings...) + return nil + } + + // Check all packages. + // + // Note that this may call checkOne recursively, so it's not guaranteed + // to evaluate in the order provided here. We do ensure however, that + // all packages are evaluated. + for pkg := range packages { + checkOne(pkg) + } + + // Sanity check. + if len(stdlibFacts) == 0 { + return nil, nil, fmt.Errorf("no stdlib facts found: misconfiguration?") + } + + // Write out all findings. + factData, err := json.Marshal(stdlibFacts) + if err != nil { + return nil, nil, fmt.Errorf("error saving stdlib facts: %w", err) + } + + // Return all findings. + return allFindings, factData, nil +} + // checkPackage runs all analyzers. // // The implementation was adapted from [1], which was in turn adpated from [2]. @@ -136,11 +338,12 @@ func (i *importer) Import(path string) (*types.Package, error) { // // [1] bazelbuid/rules_go/tools/builders/nogo_main.go // [2] golang.org/x/tools/go/checker/internal/checker -func checkPackage(config pkgConfig) ([]string, error) { +func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, importCallback func(string) error) ([]string, []byte, error) { imp := &importer{ - pkgConfig: config, - fset: token.NewFileSet(), - cache: make(map[string]*types.Package), + packageConfig: config, + fset: token.NewFileSet(), + cache: make(map[string]*types.Package), + callback: importCallback, } // Load all source files. @@ -148,14 +351,14 @@ func checkPackage(config pkgConfig) ([]string, error) { for _, file := range config.GoFiles { include, err := config.shouldInclude(file) if err != nil { - return nil, fmt.Errorf("error evaluating file %q: %v", file, err) + return nil, nil, fmt.Errorf("error evaluating file %q: %v", file, err) } if !include { continue } s, err := parser.ParseFile(imp.fset, file, nil, parser.ParseComments) if err != nil { - return nil, fmt.Errorf("error parsing file %q: %v", file, err) + return nil, nil, fmt.Errorf("error parsing file %q: %v", file, err) } syntax = append(syntax, s) } @@ -172,18 +375,19 @@ func checkPackage(config pkgConfig) ([]string, error) { Selections: make(map[*ast.SelectorExpr]*types.Selection), } types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) - if err != nil { - return nil, fmt.Errorf("error checking types: %v", err) + if err != nil && imp.lastErr != ErrSkip { + return nil, nil, fmt.Errorf("error checking types: %w", err) } // Load all package facts. - facts, err := facts.Decode(types, config.loadFacts) + loader, err := config.factLoader() if err != nil { - return nil, fmt.Errorf("error decoding facts: %v", err) + return nil, nil, fmt.Errorf("error loading facts: %w", err) + } + facts, err := facts.Decode(types, loader) + if err != nil { + return nil, nil, fmt.Errorf("error decoding facts: %w", err) } - - // Set the binary global for use. - data.Objdump = config.Objdump // Register fact types and establish dependencies between analyzers. // The visit closure will execute recursively, and populate results @@ -204,7 +408,7 @@ func checkPackage(config pkgConfig) ([]string, error) { } // Prepare the matcher. - m := analyzerConfig[a] + m := ac[a] report := func(d analysis.Diagnostic) { if m.ShouldReport(d, imp.fset) { diagnostics[a] = append(diagnostics[a], d) @@ -245,18 +449,13 @@ func checkPackage(config pkgConfig) ([]string, error) { return nil // Success. } - // Visit all analysis recursively. - for a, _ := range analyzerConfig { - if err := visit(a); err != nil { - return nil, err // Already has context. + // Visit all analyzers recursively. + for a, _ := range ac { + if imp.lastErr == ErrSkip { + continue // No local analysis. } - } - - // Write the output file. - if config.FactOutput != "" { - factData := facts.Encode() - if err := ioutil.WriteFile(config.FactOutput, factData, 0644); err != nil { - return nil, fmt.Errorf("error: unable to open facts output %q: %v", config.FactOutput, err) + if err := visit(a); err != nil { + return nil, nil, err // Already has context. } } @@ -270,47 +469,104 @@ func checkPackage(config pkgConfig) ([]string, error) { } // Return all findings. - return findings, nil + factData := facts.Encode() + return findings, factData, nil } var ( - configFile = flag.String("config", "", "configuration file (in JSON format)") + packageFile = flag.String("package", "", "package configuration file (in JSON format)") + stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") + findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") + factsOutput = flag.String("facts", "", "output file for facts (optional)") + escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") ) -// Main is the entrypoint; it should be called directly from main. -// -// N.B. This package registers it's own flags. -func Main() { - // Parse all flags. - flag.Parse() - +func loadConfig(file string, config interface{}) interface{} { // Load the configuration. - f, err := os.Open(*configFile) + f, err := os.Open(file) if err != nil { - log.Fatalf("unable to open configuration %q: %v", *configFile, err) + log.Fatalf("unable to open configuration %q: %v", file, err) } defer f.Close() - config := new(pkgConfig) dec := json.NewDecoder(f) dec.DisallowUnknownFields() if err := dec.Decode(config); err != nil { log.Fatalf("unable to decode configuration: %v", err) } + return config +} + +// Main is the entrypoint; it should be called directly from main. +// +// N.B. This package registers it's own flags. +func Main() { + // Parse all flags. + flag.Parse() + + var ( + findings []string + factData []byte + err error + ) + + // Check the configuration. + if *packageFile != "" && *stdlibFile != "" { + log.Fatalf("unable to perform stdlib and package analysis; provide only one!") + } else if *stdlibFile != "" { + // Perform basic analysis. + c := loadConfig(*stdlibFile, new(stdlibConfig)).(*stdlibConfig) + findings, factData, err = checkStdlib(c, analyzerConfig) + } else if *packageFile != "" { + // Perform basic analysis. + c := loadConfig(*packageFile, new(packageConfig)).(*packageConfig) + findings, factData, err = checkPackage(c, analyzerConfig, nil) + // Do we need to do escape analysis? + if *escapesOutput != "" { + escapes, _, err := checkPackage(c, escapesConfig, nil) + if err != nil { + log.Fatalf("error performing escape analysis: %v", err) + } + f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Fatalf("unable to open output %q: %v", *escapesOutput, err) + } + defer f.Close() + for _, escape := range escapes { + fmt.Fprintf(f, "%s\n", escape) + } + } + } else { + log.Fatalf("please provide at least one of package or stdlib!") + } + + // Save facts. + if *factsOutput != "" { + if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { + log.Fatalf("error saving findings to %q: %v", *factsOutput, err) + } + } - // Process the package. - findings, err := checkPackage(*config) + // Open the output file. + var w io.Writer = os.Stdout + if *findingsOutput != "" { + f, err := os.OpenFile(*findingsOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Fatalf("unable to open output %q: %v", *findingsOutput, err) + } + defer f.Close() + w = f + } + + // Handle findings & errors. if err != nil { log.Fatalf("error checking package: %v", err) } - - // No findings? if len(findings) == 0 { - os.Exit(0) + return } - // Print findings and exit with non-zero code. + // Print findings. for _, finding := range findings { - fmt.Fprintf(os.Stdout, "%s\n", finding) + fmt.Fprintf(w, "%s\n", finding) } - os.Exit(1) } diff --git a/tools/nogo/register.go b/tools/nogo/register.go index 62b499661..34b173937 100644 --- a/tools/nogo/register.go +++ b/tools/nogo/register.go @@ -26,6 +26,9 @@ func analyzers() (all []*analysis.Analyzer) { for a, _ := range analyzerConfig { all = append(all, a) } + for a, _ := range escapesConfig { + all = append(all, a) + } return all } diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD new file mode 100644 index 000000000..7ab340b51 --- /dev/null +++ b/tools/nogo/util/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "util", + srcs = ["util.go"], + visibility = ["//visibility:public"], +) diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go new file mode 100644 index 000000000..919fec799 --- /dev/null +++ b/tools/nogo/util/util.go @@ -0,0 +1,85 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package util contains nogo-related utilities. +package util + +import ( + "fmt" + "io/ioutil" + "regexp" + "strconv" + "strings" +) + +// findingRegexp is used to parse findings. +var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`) + +const ( + categoryIndex = 1 + fullPathAndLineIndex = 2 + fullPathIndex = 3 + lineIndex = 4 + messageIndex = 7 +) + +// Finding is a single finding. +type Finding struct { + Category string + Path string + Line int + Message string +} + +// ExtractFindingsFromFile loads findings from a file. +func ExtractFindingsFromFile(filename string) ([]Finding, error) { + content, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return ExtractFindingsFromBytes(content) +} + +// ExtractFindingsFromBytes loads findings from bytes. +func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { + lines := strings.Split(string(content), "\n") + for _, singleLine := range lines { + // Skip blank lines. + singleLine = strings.TrimSpace(singleLine) + if singleLine == "" { + continue + } + m := findingRegexp.FindStringSubmatch(singleLine) + if m == nil { + // We shouldn't see findings like this. + return findings, fmt.Errorf("poorly formated line: %v", singleLine) + } + if m[fullPathAndLineIndex] == "-" { + continue // No source file available. + } + // Cleanup the message. + message := m[messageIndex] + message = strings.Replace(message, " → ", "\n → ", -1) + message = strings.Replace(message, " or ", "\n or ", -1) + // Construct a new annotation. + lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32) + findings = append(findings, Finding{ + Category: m[categoryIndex], + Path: m[fullPathIndex], + Line: int(lineNumber), + Message: message, + }) + } + return findings, nil +} diff --git a/tools/vm/BUILD b/tools/vm/BUILD index f7160c627..d95ca6c63 100644 --- a/tools/vm/BUILD +++ b/tools/vm/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_binary", "gtest") +load("//tools:defs.bzl", "bzl_library", "cc_binary", "gtest") load("//tools/vm:defs.bzl", "vm_image", "vm_test") package( @@ -55,3 +55,9 @@ vm_test( shard_count = 2, targets = [":test"], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/vm/README.md b/tools/vm/README.md index 898c95fca..1e9859e66 100644 --- a/tools/vm/README.md +++ b/tools/vm/README.md @@ -25,6 +25,12 @@ vm_image( These images can be built manually by executing the target. The output on `stdout` will be the image id (in the current project). +For example: + +``` +$ bazel build :ubuntu +``` + Images are always named per the hash of all the hermetic input scripts. This allows images to be memoized quickly and easily. diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl index 0f67cfa92..9af5ad3b4 100644 --- a/tools/vm/defs.bzl +++ b/tools/vm/defs.bzl @@ -60,11 +60,12 @@ def _vm_image_impl(ctx): # Run the builder to generate our output. echo = ctx.actions.declare_file(ctx.label.name) resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( - command = "echo -ne \"#!/bin/bash\\nset -e\\nimage=$(%s)\\necho ${image}\\n\" > %s && chmod 0755 %s" % ( - ctx.files.builder[0].path, - echo.path, - echo.path, - ), + command = "\n".join([ + "set -e", + "image=$(%s)" % ctx.files.builder[0].path, + "echo -ne \"#!/bin/bash\\necho ${image}\\n\" > %s" % echo.path, + "chmod 0755 %s" % echo.path, + ]), tools = [ctx.attr.builder], ) ctx.actions.run_shell( diff --git a/tools/vm/ubuntu1604/30_containerd.sh b/tools/vm/ubuntu1604/30_containerd.sh deleted file mode 100755 index fb3699c12..000000000 --- a/tools/vm/ubuntu1604/30_containerd.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Helper for Go packages below. -install_helper() { - PACKAGE="${1}" - TAG="${2}" - GOPATH="${3}" - - # Clone the repository. - mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ - git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" - - # Checkout and build the repository. - (cd "${GOPATH}"/src/"${PACKAGE}" && \ - git checkout "${TAG}" && \ - GOPATH="${GOPATH}" make && \ - GOPATH="${GOPATH}" make install) -} - -# Install dependencies for the crictl tests. -while true; do - if (apt-get update && apt-get install -y \ - btrfs-tools \ - libseccomp-dev); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install containerd & cri-tools. -GOPATH=$(mktemp -d --tmpdir gopathXXXXX) -install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}" -install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}" - -# Install gvisor-containerd-shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -mv ${shim_path} /usr/local/bin - -# Configure containerd-shim. -declare -r shim_config_path=/etc/containerd -declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml) -mkdir -p ${shim_config_path} -cat > ${shim_config_tmp_path} <<-EOF - runc_shim = "/usr/local/bin/containerd-shim" - -[runsc_config] - debug = "true" - debug-log = "/tmp/runsc-logs/" - strace = "true" - file-access = "shared" -EOF -mv ${shim_config_tmp_path} ${shim_config_path} - -# Configure CNI. -(cd "${GOPATH}" && GOPATH="${GOPATH}" \ - src/github.com/containerd/containerd/script/setup/install-cni) - -# Cleanup the above. -rm -rf "${GOPATH}" -rm -rf "${latest}" -rm -rf "${shim_path}" -rm -rf "${shim_config_tmp_path}" diff --git a/tools/vm/ubuntu1604/25_docker.sh b/tools/vm/ubuntu1604/30_docker.sh index 53d8ca588..d393133e4 100755 --- a/tools/vm/ubuntu1604/25_docker.sh +++ b/tools/vm/ubuntu1604/30_docker.sh @@ -53,13 +53,12 @@ while true; do fi done +# Enable experimental features, for cross-building aarch64 images. # Enable Docker IPv6. cat > /etc/docker/daemon.json <<EOF { + "experimental": true, "fixed-cidr-v6": "2001:db8:1::/64", "ipv6": true } EOF -# Docker's IPv6 support is lacking and does not work the same way as IPv4. We -# can use NAT so containers can reach the outside world. -ip6tables -t nat -A POSTROUTING -s 2001:db8:1::/64 ! -o docker0 -j MASQUERADE diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh index 2974f156c..d3b96c9ad 100755 --- a/tools/vm/ubuntu1604/40_kokoro.sh +++ b/tools/vm/ubuntu1604/40_kokoro.sh @@ -41,7 +41,7 @@ while true; do done # junitparser is used to merge junit xml files. -pip install junitparser +pip install --no-cache-dir junitparser # We need a kbuilder user, which may already exist. useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true diff --git a/website/BUILD b/website/BUILD index 4488cb543..6d92d9103 100644 --- a/website/BUILD +++ b/website/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "pkg_tar") +load("//tools:defs.bzl", "bzl_library", "pkg_tar") load("//website:defs.bzl", "doc", "docs") package(licenses = ["notice"]) @@ -55,9 +55,7 @@ genrule( "docker run -i --user $$(id -u):$$(id -g) " + "-v $$(readlink -m $$T/output/_site):/output " + "gvisor.dev/images/jekyll " + - "/usr/gem/bin/htmlproofer " + - "--disable-external " + - "--check-html " + + "ruby /checks.rb " + "/output && " + "cp $(location //website/cmd/server) $$T/output/server && " + "tar -zcf $@ -C $$T/output . && " + @@ -151,11 +149,15 @@ docs( "//g3doc/user_guide:install", "//g3doc/user_guide:networking", "//g3doc/user_guide:platforms", + "//g3doc/user_guide/containerd:configuration", + "//g3doc/user_guide/containerd:containerd_11", + "//g3doc/user_guide/containerd:quick_start", "//g3doc/user_guide/quick_start:docker", "//g3doc/user_guide/quick_start:kubernetes", "//g3doc/user_guide/quick_start:oci", "//g3doc/user_guide/tutorials:cni", "//g3doc/user_guide/tutorials:docker", + "//g3doc/user_guide/tutorials:docker_compose", "//g3doc/user_guide/tutorials:kubernetes", ], ) @@ -179,3 +181,9 @@ genrule( "rm -rf $$T", tools = ["//website/cmd/syscalldocs"], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/website/_config.yml b/website/_config.yml index b08602970..20fbb3d2d 100644 --- a/website/_config.yml +++ b/website/_config.yml @@ -34,3 +34,6 @@ authors: igudger: name: Ian Gudger email: igudger@google.com + fvoznika: + name: Fabricio Voznika + email: fvoznika@google.com diff --git a/website/_includes/footer.html b/website/_includes/footer.html index 9cc8176f7..c1a373329 100644 --- a/website/_includes/footer.html +++ b/website/_includes/footer.html @@ -8,7 +8,7 @@ <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.13.0/d3.min.js" integrity="sha256-hYXbQJK4qdJiAeDVjjQ9G0D6A0xLnDQ4eJI9dkm7Fpk=" crossorigin="anonymous"></script> {% if site.analytics %} -<script type="application/javascript"> +<script> var doNotTrack = false; if (!doNotTrack) { window.ga=window.ga||function(){(ga.q=ga.q||[]).push(arguments)};ga.l=+new Date; diff --git a/website/_includes/graph.html b/website/_includes/graph.html index f3a999341..ba4cf9840 100644 --- a/website/_includes/graph.html +++ b/website/_includes/graph.html @@ -1,7 +1,7 @@ {::nomarkdown} {% assign fn = include.id | remove: " " | remove: "-" | downcase %} <figure><a href="{{ include.url }}"><svg id="{{ include.id }}" width=500 height=200 onload="render_{{ fn }}()"><title>{{ include.title }}</title></svg></a></figure> -<script type="text/javascript"> +<script> function render_{{ fn }}() { d3.csv("{{ include.url }}", function(d, i, columns) { return d; // Transformed below. diff --git a/website/_includes/header-links.html b/website/_includes/header-links.html index 467bb1e72..4232fdaa5 100644 --- a/website/_includes/header-links.html +++ b/website/_includes/header-links.html @@ -2,7 +2,7 @@ <div class="container"> <div class="navbar-brand"> <a href="/"> - <img src="/assets/logos/logo_solo_on_dark.svg" height="25px" class="d-inline-block align-top" style="margin-right: 10px;" alt="logo"/> + <img src="/assets/logos/logo_solo_on_dark.svg" height="25" class="d-inline-block align-top" style="margin-right: 10px;" alt="logo" /> gVisor </a> </div> diff --git a/website/_layouts/docs.html b/website/_layouts/docs.html index 549305089..0422f9fb0 100644 --- a/website/_layouts/docs.html +++ b/website/_layouts/docs.html @@ -14,30 +14,25 @@ categories: {% for category in layout.categories %} <h3>{{ category }}</h3> <ul class="sidebar-nav"> - {% assign sorted_pages = site.pages | where: 'layout', 'docs' | where: 'category', category | sort: 'weight' | sort: 'subcategory' %} - {% assign subcategory = nil %} - {% for p in sorted_pages %} - {% if p.subcategory != subcategory %} - {% if subcategory != nil %} - </ul> - </li> - {% endif %} - {% assign subcategory = p.subcategory %} - {% if subcategory != nil %} - {% assign ac = "aria-controls" %} - {% assign cid = p.category | remove: " " | downcase %} - {% assign sid = p.subcategory | remove: " " | downcase %} - <li> - <a class="sidebar-nav-heading" data-toggle="collapse" href="#{{ cid }}-{{ sid }}" aria-expanded="false" {{ ac }}="{{ cid }}-{{ sid }}">{{ subcategory }}<span class="caret"></span></a> - <ul class="collapse sidebar-nav sidebar-submenu" id="{{ cid }}-{{ sid }}"> - {% endif %} + {% assign subcats = site.pages | where: 'layout', 'docs' | where: 'category', category | group_by: 'subcategory' | sort: 'name', 'first' %} + {% for subcategory in subcats %} + {% assign sorted_pages = subcategory.items | sort: 'weight', 'last' %} + {% if subcategory.name != "" %} + {% assign ac = "aria-controls" %} + {% assign cid = category | remove: " " | downcase %} + {% assign sid = subcategory.name | remove: " " | downcase %} + <li> + <a class="sidebar-nav-heading" data-toggle="collapse" href="#{{ cid }}-{{ sid }}" aria-expanded="false" {{ ac }}="{{ cid }}-{{ sid }}">{{ subcategory.name }}<span class="caret"></span></a> + <ul class="collapse sidebar-nav sidebar-submenu" id="{{ cid }}-{{ sid }}"> {% endif %} - <li><a href="{{ p.url }}">{{ p.title }}</a></li> - {% endfor %} - {% if subcategory != nil %} - </ul> + {% for p in sorted_pages %} + <li><a href="{{ p.url }}">{{ p.title }}</a></li> + {% endfor %} + {% if subcategory.name != "" %} </li> - {% endif %} + </ul> + {% endif %} + {% endfor %} </ul> {% endfor %} </nav> @@ -47,8 +42,8 @@ categories: <h1>{{ page.title }}</h1> {% if page.editpath %} <p> - <a href="https://github.com/google/gvisor/edit/master/{{page.editpath}}" target="_blank"><i class="fa fa-edit fa-fw"></i> Edit this page</a> - <a href="https://github.com/google/gvisor/issues/new?title={{page.title | url_encode}}" target="_blank"><i class="fab fa-github fa-fw"></i> Create issue</a> + <a href="https://github.com/google/gvisor/edit/master/{{page.editpath}}" target="_blank" rel="noopener"><i class="fa fa-edit fa-fw"></i> Edit this page</a> + <a href="https://github.com/google/gvisor/issues/new?title={{page.title | url_encode}}" target="_blank" rel="noopener"><i class="fab fa-github fa-fw"></i> Create issue</a> </p> {% endif %} <div class="docs-content"> diff --git a/website/_sass/front.scss b/website/_sass/front.scss index 0e4208f3c..f1b060560 100644 --- a/website/_sass/front.scss +++ b/website/_sass/front.scss @@ -1,5 +1,5 @@ .jumbotron { - background-image: url(/assets/images/background.jpg); + background-image: url(/assets/images/background_1080p.jpg); background-position: center; background-repeat: no-repeat; background-size: cover; diff --git a/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png b/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png Binary files differnew file mode 100644 index 000000000..c750f0851 --- /dev/null +++ b/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png diff --git a/website/assets/images/background_1080p.jpg b/website/assets/images/background_1080p.jpg Binary files differnew file mode 100644 index 000000000..d312595a6 --- /dev/null +++ b/website/assets/images/background_1080p.jpg diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md index fbdd511dd..b6cf57a77 100644 --- a/website/blog/2019-11-18-security-basics.md +++ b/website/blog/2019-11-18-security-basics.md @@ -44,10 +44,10 @@ into it in the next section! # Design Principles -gVisor was designed with some -[common secure design principles](https://www.owasp.org/index.php/Security_by_Design_Principles) -in mind: Defense-in-Depth, Principle of Least-Privilege, Attack Surface -Reduction and Secure-by-Default[^1]. +gVisor was designed with some common +[secure design](https://en.wikipedia.org/wiki/Secure_by_design) principles in +mind: Defense-in-Depth, Principle of Least-Privilege, Attack Surface Reduction +and Secure-by-Default[^1]. In general, Design Principles outline good engineering practices, but in the case of security, they also can be thought of as a set of tactics. In a @@ -188,7 +188,7 @@ for direct access to some files. And most files will be remotely accessed through the Gofers, in which case no FDs are donated to the Sentry. The Sentry itself is only allowed access to specific -[whitelisted syscalls](https://github.com/google/gvisor/blob/master/runsc/boot/config.go). +[whitelisted syscalls](https://github.com/google/gvisor/blob/master/runsc/config/config.go). Without networking, the Sentry needs 53 host syscalls in order to function, and with networking, it uses an additional 15[^8]. By limiting the whitelist to only these needed syscalls, we radically reduce the amount of host OS attack surface. @@ -279,19 +279,28 @@ weaknesses of each gVisor component. We will also use it to introduce Google's Vulnerability Reward Program[^14], and other ways the community can contribute to help make gVisor safe, fast and stable. +<br> +<br> -## Notes +-------------------------------------------------------------------------------- -[^1]: [https://www.owasp.org/index.php/Security_by_Design_Principles](https://www.owasp.org/index.php/Security_by_Design_Principles) +[^1]: [https://en.wikipedia.org/wiki/Secure_by_design](https://en.wikipedia.org/wiki/Secure_by_design) [^2]: [https://gvisor.dev/docs/architecture_guide](https://gvisor.dev/docs/architecture_guide/) [^3]: [https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/linux64_amd64.go](https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/syscalls.go) -[^4]: Internally that is, it doesn't call to the Host OS to implement them, in - fact that is explicitly disallowed, more on that in the future. + +<!-- mdformat off(mdformat formats this into multiple lines) --> +[^4]: Internally that is, it doesn't call to the Host OS to implement them, in fact that is explicitly disallowed, more on that in the future. +<!-- mdformat on --> + [^5]: [https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345](https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345) [^6]: [https://github.com/google/gvisor/tree/master/runsc/boot/filter](https://github.com/google/gvisor/tree/master/runsc/boot/filter) [^7]: [https://en.wikipedia.org/wiki/Dirty_COW](https://en.wikipedia.org/wiki/Dirty_COW) [^8]: [https://github.com/google/gvisor/blob/master/runsc/boot/config.go](https://github.com/google/gvisor/blob/master/runsc/boot/config.go) -[^9]: [https://en.wikipedia.org/wiki/9P_(protocol)](https://en.wikipedia.org/wiki/9P_\(protocol\)) + +<!-- mdformat off(mdformat breaks this url by escaping the parenthesis) --> +[^9]: [https://en.wikipedia.org/wiki/9P_(protocol)](https://en.wikipedia.org/wiki/9P_(protocol)) +<!-- mdformat on --> + [^10]: [https://gvisor.dev/docs/user_guide/networking/#network-passthrough](https://gvisor.dev/docs/user_guide/networking/#network-passthrough) [^11]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390) [^12]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182) diff --git a/website/blog/2020-04-02-networking-security.md b/website/blog/2020-04-02-networking-security.md index 5a5e38fd7..f3ce02d11 100644 --- a/website/blog/2020-04-02-networking-security.md +++ b/website/blog/2020-04-02-networking-security.md @@ -108,7 +108,7 @@ re-architecting the TCP implementation to use fewer goroutines. Performance today is good enough for most applications and we are making steady improvements. For example, since May of 2019, we have improved the Netstack runsc -[iperf3 download benchmark](https://github.com/google/gvisor/blob/master/benchmarks/suites/network.py) +[iperf3 download benchmark](https://github.com/google/gvisor/tree/master/test/benchmarks/network) score by roughly 15% and upload score by around 10,000X. Current numbers are about 17 Gbps download and about 8 Gbps upload versus about 42 Gbps and 43 Gbps for native (Linux) respectively. diff --git a/website/blog/2020-09-18-containing-a-real-vulnerability.md b/website/blog/2020-09-18-containing-a-real-vulnerability.md new file mode 100644 index 000000000..c1b06a996 --- /dev/null +++ b/website/blog/2020-09-18-containing-a-real-vulnerability.md @@ -0,0 +1,223 @@ +# Containing a Real Vulnerability + +In the previous two posts we talked about gVisor's +[security design principles](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/) +as well as how those are applied in the +[context of networking](https://gvisor.dev/blog/2020/04/02/gvisor-networking-security/). +Recently, a new container escape vulnerability +([CVE-2020-14386](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14386)) +was announced that ties these topics well together. gVisor is +[not vulnerable](https://seclists.org/oss-sec/2020/q3/168) to this specific +issue, but it provides an interesting case study to continue our exploration of +gVisor's security. While gVisor is not immune to vulnerabilities, +[we take several steps](https://gvisor.dev/security/) to minimize the impact and +remediate if a vulnerability is found. + +## Escaping the Container + +First, let’s describe how the discovered vulnerability works. There are numerous +ways one can send and receive bytes over the network with Linux. One of the most +performant ways is to use a ring buffer, which is a memory region shared by the +application and the kernel. These rings are created by calling +[setsockopt(2)](https://man7.org/linux/man-pages/man2/setsockopt.2.html) with +[`PACKET_RX_RING`](https://man7.org/linux/man-pages/man7/packet.7.html) for +receiving and +[`PACKET_TX_RING`](https://man7.org/linux/man-pages/man7/packet.7.html) for +sending packets. + +The vulnerability is in the code that reads packets when `PACKET_RX_RING` is +enabled. There is another option +([`PACKET_RESERVE`](https://man7.org/linux/man-pages/man7/packet.7.html)) that +asks the kernel to leave some space in the ring buffer before each packet for +anything the application needs, e.g. control structures. When a packet is +received, the kernel calculates where to copy the packet to, taking the amount +reserved before each packet into consideration. If the amount reserved is large, +the kernel performed an incorrect calculation which could cause an overflow +leading to an out-of-bounds write of up to 10 bytes, controlled by the attacker. +The data in the write is easily controlled using the loopback to send a crafted +packet and receiving it using a `PACKET_RX_RING` with a carefully selected +`PACKET_RESERVE` size. + +```c +static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ +// ... + if (sk->sk_type == SOCK_DGRAM) { + macoff = netoff = TPACKET_ALIGN(po->tp_hdrlen) + 16 + + po->tp_reserve; + } else { + unsigned int maclen = skb_network_offset(skb); + // tp_reserve is unsigned int, netoff is unsigned short. Addition can overflow netoff + netoff = TPACKET_ALIGN(po->tp_hdrlen + + (maclen < 16 ? 16 : maclen)) + + po->tp_reserve; + if (po->has_vnet_hdr) { + netoff += sizeof(struct virtio_net_hdr); + do_vnet = true; + } + // Attacker controls netoff and can make macoff be smaller than sizeof(struct virtio_net_hdr) + macoff = netoff - maclen; + } +// ... + // "macoff - sizeof(struct virtio_net_hdr)" can be negative, resulting in a pointer before h.raw + if (do_vnet && + virtio_net_hdr_from_skb(skb, h.raw + macoff - + sizeof(struct virtio_net_hdr), + vio_le(), true, 0)) { +// ... +``` + +The [`CAP_NET_RAW`](https://man7.org/linux/man-pages/man7/capabilities.7.html) +capability is required to create the socket above. However, in order to support +common debugging tools like `ping` and `tcpdump`, Docker containers, including +those created for Kubernetes, are given `CAP_NET_RAW` by default and thus may be +able to trigger this vulnerability to elevate privileges and escape the +container. + +Next, we are going to explore why this vulnerability doesn’t work in gVisor, and +how gVisor could prevent the escape even if a similar vulnerability existed +inside gVisor’s kernel. + +## Default Protections + +gVisor does not implement `PACKET_RX_RING`, but **does** support raw sockets +which are required for `PACKET_RX_RING`. Raw sockets are a controversial feature +to support in a sandbox environment. While it allows great customizations for +essential tools like `ping`, it may allow packets to be written to the network +without any validation. In general, allowing an untrusted application to write +crafted packets to the network is a questionable idea and a historical source of +vulnerabilities. With that in mind, if `CAP_NET_RAW` is enabled by default, it +would not be _secure by default_ to run untrusted applications. + +After multiple discussions when raw sockets were first implemented, we decided +to disable raw sockets by default, **even if `CAP_NET_RAW` is given to the +application**. Instead, enabling raw sockets in gVisor requires the admin to set +`--net-raw` flag to runsc when configuring the runtime, in addition to requiring +the `CAP_NET_RAW` capability in the application. It comes at the expense that +some tools may not work out of the box, but as part of our +[secure-by-default](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#secure-by-default) +principle, we felt that it was important for the “less secure” configuration to +be explicit. + +Since this bug was due to an overflow in the specific Linux implementation of +the packet ring, gVisor's raw socket implementation is not affected. However, if +there were a vulnerability in gVisor, containers would not be allowed to exploit +it by default. + +As an alternative way to implement this same constraint, Kubernetes allows +[admission controllers](https://kubernetes.io/docs/reference/access-authn-authz/admission-controllers/) +to be configured to customize requests. Cloud providers can use this to +implement more stringent policies. For example, GKE implements an admission +controller for gVisor that +[removes `CAP_NET_RAW` from gVisor pods](https://cloud.google.com/kubernetes-engine/docs/concepts/sandbox-pods#capabilities) +unless it has been explicitly set in the pod spec. + +## Isolated Kernel + +gVisor has its own application kernel, called the Sentry, that is distinct from +the host kernel. Just like what you would expect from a kernel, gVisor has a +memory management subsystem, virtual file system, and a full network stack. The +host network is only used as a transport to carry packets in and out the +sandbox[^1]. The loopback interface which is used in the exploit stays +completely inside the sandbox, never reaching the host. + +Therefore, even if the Sentry was vulnerable to the attack, there would be two +factors that would prevent a container escape from happening. First, the +vulnerability would be limited to the Sentry, and the attacker would compromise +only the application kernel, bound by a restricted set of +[seccomp](https://en.wikipedia.org/wiki/Seccomp) filters, discussed more in +depth below. Second, the Sentry is a distinct implementation of the API, written +in Go, which provides bounds checking that would have likely prevented access +past the bounds of the shared region (e.g. see +[aio](https://cs.opensource.google/gvisor/gvisor/+/master:pkg/sentry/syscalls/linux/vfs2/aio.go;l=210;drc=a11061d78a58ed75b10606d1a770b035ed944b66?q=file:aio&ss=gvisor%2Fgvisor) +or +[kcov](https://cs.opensource.google/gvisor/gvisor/+/master:pkg/sentry/kernel/kcov.go;l=272?q=file:kcov&ss=gvisor%2Fgvisor), +which have similar shared regions). + +Here, Kubernetes warrants slightly more explanation. gVisor makes pods the unit +of isolation and a pod can run multiple containers. In other words, each pod is +a gVisor instance, and each container is a set of processes running inside +gVisor, isolated via Sentry-internal namespaces like regular containers inside a +pod. If there were a vulnerability in gVisor, the privilege escalation would +allow a container inside the pod to break out to other **containers inside the +same pod**, but the container still **cannot break out of the pod**. + +## Defense in Depth + +gVisor follows a +[common security principle used at Google](https://cloud.google.com/security/infrastructure/design/resources/google_infrastructure_whitepaper_fa.pdf) +that the system should have two layers of protection, and those layers should +require different compromises to be broken. We apply this principle by assuming +that the Sentry (first layer of defense) +[will be compromised and should not be trusted](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#defense-in-depth). +In order to protect the host kernel from a compromised Sentry, we wrap it around +many security and isolations features to ensure only the minimal set of +functionality from the host kernel is exposed. + + + +First, the sandbox runs inside a cgroup that can limit and throttle host +resources being used. Second, the sandbox joins empty namespaces, including user +and mount, to further isolate from the host. Next, it changes the process root +to a read-only directory that contains only `/proc` and nothing else. Then, it +executes with the unprivileged user/group +[`nobody`](https://en.wikipedia.org/wiki/Nobody_\(username\)) with all +capabilities stripped. Last and most importantly, a seccomp filter is added to +tightly restrict what parts of the Linux syscall surface that gVisor is allowed +to access. The allowed host surface is a far smaller set of syscalls than the +Sentry implements for applications to use. Not only restricting the syscall +being called, but also checking that arguments to these syscalls are within the +expected set. Dangerous syscalls like <code>execve(2)</code>, +<code>open(2)</code>, and <code>socket(2)</code> are prohibited, thus an +attacker isn’t able to execute binaries or acquire new resources on the host. + +if there were a vulnerability in gVisor that allowed an attacker to execute code +inside the Sentry, the attacker still has extremely limited privileges on the +host. In fact, a compromised Sentry is much more restricted than a +non-compromised regular container. For CVE-2020-14386 in particular, the attack +would be blocked by more than one security layer: non-privileged user, no +capability, and seccomp filters. + +Although the surface is drastically reduced, there is still a chance that there +is a vulnerability in one of the allowed syscalls. That’s why it’s important to +keep the surface small and carefully consider what syscalls are allowed. You can +find the full set of allowed syscalls +[here](https://cs.opensource.google/gvisor/gvisor/+/master:runsc/boot/filter/). + +Another possible attack vector is resources that are present in the Sentry, like +open file descriptors. The Sentry has file descriptors that an attacker could +potentially use, such as log files, platform files (e.g. `/dev/kvm`), an RPC +endpoint that allows external communication with the Sentry, and a Netstack +endpoint that connects the sandbox to the network. The Netstack endpoint in +particular is a concern because it gives direct access to the network. It’s an +`AF_PACKET` socket that allows arbitrary L2 packets to be written to the +network. In the normal case, Netstack assembles packets that go out the network, +giving the container control over only the payload. But if the Sentry is +compromised, an attacker can craft packets to the network. In many ways this is +similar to anyone sending random packets over the internet, but still this is a +place where the host kernel surface exposed is larger than we would like it to +be. + +## Conclusion + +Security comes with many tradeoffs that are often hard to make, such as the +decision to disable raw sockets by default. However, these tradeoffs have served +us well, and we've found them to have paid off over time. CVE-2020-14386 offers +great insight into how multiple layers of protection can be effective against +such an attack. + +We cannot guarantee that a container escape will never happen in gVisor, but we +do our best to make it as hard as we possibly can. + +If you have not tried gVisor yet, it’s easier than you think. Just follow the +steps [here](https://gvisor.dev/docs/user_guide/install/). +<br> +<br> + +-------------------------------------------------------------------------------- + +[^1]: Those packets are eventually handled by the host, as it needs to route + them to local containers or send them out the NIC. The packet will be + handled by many switches, routers, proxies, servers, etc. along the way, + which may be subject to their own vulnerabilities. diff --git a/website/blog/BUILD b/website/blog/BUILD index 01c1f5a6e..865e403da 100644 --- a/website/blog/BUILD +++ b/website/blog/BUILD @@ -28,6 +28,16 @@ doc( permalink = "/blog/2020/04/02/gvisor-networking-security/", ) +doc( + name = "containing_a_real_vulnerability", + src = "2020-09-18-containing-a-real-vulnerability.md", + authors = [ + "fvoznika", + ], + layout = "post", + permalink = "/blog/2020/09/18/containing-a-real-vulnerability/", +) + docs( name = "posts", deps = [ diff --git a/website/css/main.scss b/website/css/main.scss index 06106833f..4b3b7b500 100644 --- a/website/css/main.scss +++ b/website/css/main.scss @@ -1,5 +1,10 @@ -@import 'style.scss'; -@import 'front.scss'; -@import 'navbar.scss'; -@import 'sidebar.scss'; -@import 'footer.scss'; +// The main style sheet for gvisor.dev + +// NOTE: Do not include file extensions to import .sass and .css files seamlessly. +@import "style"; +@import "front"; +@import "navbar"; +@import "sidebar"; +@import "footer"; +// syntax is generated by rougify. +@import "syntax"; diff --git a/website/defs.bzl b/website/defs.bzl index ead6a3067..f52946c15 100644 --- a/website/defs.bzl +++ b/website/defs.bzl @@ -1,5 +1,7 @@ """Wrappers for website documentation.""" +load("//tools:defs.bzl", "short_path") + # DocInfo is a provider which simple adds sufficient metadata to the source # files (and additional data files) so that a jeyll header can be constructed # dynamically. This is done the via BUILD system so that the plain @@ -29,7 +31,7 @@ def _doc_impl(ctx): category = ctx.attr.category, subcategory = ctx.attr.subcategory, weight = ctx.attr.weight, - editpath = ctx.files.src[0].short_path, + editpath = short_path(ctx.files.src[0].short_path), authors = ctx.attr.authors, ), ] diff --git a/website/index.md b/website/index.md index 84f877d49..c6cd477c2 100644 --- a/website/index.md +++ b/website/index.md @@ -5,7 +5,7 @@ <div class="col-md-6"> <p>gVisor is an <b>application kernel</b> for <b>containers</b> that provides efficient defense-in-depth anywhere.</p> <p style="margin-top: 20px;"> - <a class="btn" href="/docs/user_guide/quick_start/docker/">Quick start <i class="fas fa-arrow-alt-circle-right ml-2"></i></a> + <a class="btn" href="/docs/user_guide/install/">Get started <i class="fas fa-arrow-alt-circle-right ml-2"></i></a> <a class="btn" href="/docs/">Learn More <i class="fas fa-arrow-alt-circle-right ml-2"></i></a> </p> </div> diff --git a/website/performance/README.md b/website/performance/README.md index 0dbfd2f02..1758fc608 100644 --- a/website/performance/README.md +++ b/website/performance/README.md @@ -1,9 +1,10 @@ # Performance data -This directory holds the CSVs generated by the -[benchmark-tools][benchmark-tools] repository. +This directory holds the CSVs generated by the now removed benchmark-tools +repository. The new functionally equivalent +[benchmark-tools is available.][benchmark-tools] In the future, these will be automatically posted to a cloud storage bucket and loaded dynamically. At that point, this directory will be removed. -[benchmark-tools]: https://github.com/google/gvisor/tree/master/benchmarks +[benchmark-tools]: https://github.com/google/gvisor/tree/master/test/benchmarks |